Tidied up code.

Some of the code was nasty. I cleaned it up a bit. I also started moving
tests into a tests/ directory. You can run python -m unittest to run
them. I still need to add more tests to make sure things are working as
expected.
This commit is contained in:
Michael Whittaker 2021-01-30 18:47:25 -08:00
parent 63b88c38d5
commit a42bccaf0c
6 changed files with 521 additions and 354 deletions

View file

@ -1,5 +1,5 @@
from .expr import Node, choose, majority from .expr import Node, choose, majority
from .quorum_system import QuorumSystem from .quorum_system import QuorumSystem, Strategy
from .viz import ( from .viz import (
plot_node_load, plot_node_load,
plot_node_load_on, plot_node_load_on,

View file

@ -1,6 +1,5 @@
from typing import Any, Callable, List, NamedTuple, Optional, Tuple from typing import Any, Callable, List, NamedTuple, Optional, Tuple
import math import math
import unittest
class Point(NamedTuple): class Point(NamedTuple):
@ -90,165 +89,3 @@ def max_of_segments(segments: List[Segment]) -> List[Tuple[float, float]]:
xs.append(p.x) xs.append(p.x)
xs.sort() xs.sort()
return [(x, max(segments, key=lambda s: s(x))(x)) for x in xs] return [(x, max(segments, key=lambda s: s(x))(x)) for x in xs]
class TestGeometry(unittest.TestCase):
def test_eq(self):
l = Point(0, 1)
r = Point(1, 1)
m = Point(0.5, 0.5)
self.assertEqual(Segment(l, r), Segment(l, r))
self.assertNotEqual(Segment(l, r), Segment(l, m))
def test_compatible(self):
s1 = Segment(Point(0, 1), Point(1, 2))
s2 = Segment(Point(0, 2), Point(1, 1))
s3 = Segment(Point(0.5, 2), Point(1, 1))
self.assertTrue(s1.compatible(s2))
self.assertTrue(s2.compatible(s1))
self.assertFalse(s1.compatible(s3))
self.assertFalse(s3.compatible(s1))
self.assertFalse(s2.compatible(s3))
self.assertFalse(s3.compatible(s2))
def test_call(self):
segment = Segment(Point(0, 0), Point(1, 1))
for x in [0.0, 0.25, 0.5, 0.75, 1.0]:
self.assertEqual(segment(x), x)
segment = Segment(Point(0, 0), Point(1, 2))
for x in [0.0, 0.25, 0.5, 0.75, 1.0]:
self.assertEqual(segment(x), 2*x)
segment = Segment(Point(1, 2), Point(3, 6))
for x in [1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]:
self.assertEqual(segment(x), 2*x)
segment = Segment(Point(0, 1), Point(1, 0))
for x in [0.0, 0.25, 0.5, 0.75, 1.0]:
self.assertEqual(segment(x), 1 - x)
def test_slope(self):
self.assertEqual(Segment(Point(0, 0), Point(1, 1)).slope(), 1.0)
self.assertEqual(Segment(Point(0, 1), Point(1, 2)).slope(), 1.0)
self.assertEqual(Segment(Point(1, 1), Point(2, 2)).slope(), 1.0)
self.assertEqual(Segment(Point(1, 1), Point(2, 3)).slope(), 2.0)
self.assertEqual(Segment(Point(1, 1), Point(2, 0)).slope(), -1.0)
def test_above(self):
s1 = Segment(Point(0, 0), Point(1, 0.5))
s2 = Segment(Point(0, 0.5), Point(1, 2))
s3 = Segment(Point(0, 1.5), Point(1, 0.5))
self.assertFalse(s1.above(s1))
self.assertFalse(s1.above(s2))
self.assertFalse(s1.above(s3))
self.assertTrue(s2.above(s1))
self.assertFalse(s2.above(s2))
self.assertFalse(s2.above(s3))
self.assertTrue(s3.above(s1))
self.assertFalse(s3.above(s2))
self.assertFalse(s3.above(s3))
def test_above_eq(self):
s1 = Segment(Point(0, 0), Point(1, 0.5))
s2 = Segment(Point(0, 0.5), Point(1, 2))
s3 = Segment(Point(0, 1.5), Point(1, 0.5))
self.assertTrue(s1.above_eq(s1))
self.assertFalse(s1.above_eq(s2))
self.assertFalse(s1.above_eq(s3))
self.assertTrue(s2.above_eq(s1))
self.assertTrue(s2.above_eq(s2))
self.assertFalse(s2.above_eq(s3))
self.assertTrue(s3.above_eq(s1))
self.assertFalse(s3.above_eq(s2))
self.assertTrue(s3.above_eq(s3))
def test_intersects(self):
s1 = Segment(Point(0, 0), Point(1, 0.5))
s2 = Segment(Point(0, 0.5), Point(1, 2))
s3 = Segment(Point(0, 1.5), Point(1, 0.5))
self.assertTrue(s1.intersects(s1))
self.assertFalse(s1.intersects(s2))
self.assertTrue(s1.intersects(s3))
self.assertFalse(s2.intersects(s1))
self.assertTrue(s2.intersects(s2))
self.assertTrue(s2.intersects(s3))
self.assertTrue(s3.intersects(s1))
self.assertTrue(s3.intersects(s2))
self.assertTrue(s3.intersects(s3))
def test_intersection(self):
s1 = Segment(Point(0, 0), Point(1, 1))
s2 = Segment(Point(0, 1), Point(1, 0))
s3 = Segment(Point(0, 1), Point(1, 1))
s4 = Segment(Point(0, 0.25), Point(1, 0.25))
self.assertEqual(s1.intersection(s1), None)
self.assertEqual(s1.intersection(s2), Point(0.5, 0.5))
self.assertEqual(s1.intersection(s3), Point(1, 1))
self.assertEqual(s1.intersection(s4), Point(0.25, 0.25))
self.assertEqual(s2.intersection(s1), Point(0.5, 0.5))
self.assertEqual(s2.intersection(s2), None)
self.assertEqual(s2.intersection(s3), Point(0, 1))
self.assertEqual(s2.intersection(s4), Point(0.75, 0.25))
self.assertEqual(s3.intersection(s1), Point(1, 1))
self.assertEqual(s3.intersection(s2), Point(0, 1))
self.assertEqual(s3.intersection(s3), None)
self.assertEqual(s3.intersection(s4), None)
self.assertEqual(s4.intersection(s1), Point(0.25, 0.25))
self.assertEqual(s4.intersection(s2), Point(0.75, 0.25))
self.assertEqual(s4.intersection(s3), None)
self.assertEqual(s4.intersection(s4), None)
def test_max_one_segment(self):
s1 = Segment(Point(0, 0), Point(1, 1))
s2 = Segment(Point(0, 1), Point(1, 0))
s3 = Segment(Point(0, 1), Point(1, 1))
s4 = Segment(Point(0, 0.25), Point(1, 0.25))
s5 = Segment(Point(0, 0.75), Point(1, 0.75))
def is_subset(xs: List[Any], ys: List[Any]) -> bool:
return all(x in ys for x in xs)
for s in [s1, s2, s3, s4, s5]:
self.assertEqual(max_of_segments([s]), [s.l, s.r])
expected = [
([s1, s1], [(0, 0), (1, 1)]),
([s1, s2], [(0, 1), (0.5, 0.5), (1, 1)]),
([s1, s3], [(0, 1), (1, 1)]),
([s1, s4], [(0, 0.25), (0.25, 0.25), (1, 1)]),
([s1, s5], [(0, 0.75), (0.75, 0.75), (1, 1)]),
([s2, s2], [(0, 1), (1, 0)]),
([s2, s3], [(0, 1), (1, 1)]),
([s2, s4], [(0, 1), (0.75, 0.25), (1, 0.25)]),
([s2, s5], [(0, 1), (0.25, 0.75), (1, 0.75)]),
([s3, s3], [(0, 1), (1, 1)]),
([s3, s4], [(0, 1), (1, 1)]),
([s3, s5], [(0, 1), (1, 1)]),
([s4, s4], [(0, 0.25), (1, 0.25)]),
([s4, s5], [(0, 0.75), (1, 0.75)]),
([s5, s5], [(0, 0.75), (1, 0.75)]),
([s1, s2, s4], [(0, 1), (0.5, 0.5), (1, 1)]),
([s1, s2, s5], [(0, 1), (0.25, 0.75), (0.75, 0.75), (1, 1)]),
]
for segments, path in expected:
self.assertTrue(is_subset(path, max_of_segments(segments)))
self.assertTrue(is_subset(path, max_of_segments(segments[::-1])))
if __name__ == '__main__':
unittest.main()

View file

@ -1,6 +1,3 @@
# TODO(mwhittaker): We can define a set of read quorums that are not minimal.
# Does this mess things up?
from . import distribution from . import distribution
from . import geometry from . import geometry
from .distribution import Distribution from .distribution import Distribution
@ -16,7 +13,6 @@ import pulp
T = TypeVar('T') T = TypeVar('T')
LOAD = 'load' LOAD = 'load'
NETWORK = 'network' NETWORK = 'network'
LATENCY = 'latency' LATENCY = 'latency'
@ -29,8 +25,7 @@ class QuorumSystem(Generic[T]):
writes: Optional[Expr[T]] = None) -> None: writes: Optional[Expr[T]] = None) -> None:
if reads is not None and writes is not None: if reads is not None and writes is not None:
optimal_writes = reads.dual() optimal_writes = reads.dual()
if not all(optimal_writes.is_quorum(write_quorum) if not all(optimal_writes.is_quorum(wq) for wq in writes.quorums()):
for write_quorum in writes.quorums()):
raise ValueError( raise ValueError(
'Not all read quorums intersect all write quorums') 'Not all read quorums intersect all write quorums')
@ -63,9 +58,15 @@ class QuorumSystem(Generic[T]):
def is_write_quorum(self, xs: Set[T]) -> bool: def is_write_quorum(self, xs: Set[T]) -> bool:
return self.writes.is_quorum(xs) return self.writes.is_quorum(xs)
def node(self, x: T) -> Node[T]:
return self.x_to_node[x]
def nodes(self) -> Set[Node[T]]: def nodes(self) -> Set[Node[T]]:
return self.reads.nodes() | self.writes.nodes() return self.reads.nodes() | self.writes.nodes()
def elements(self) -> Set[T]:
return {node.x for node in self.nodes()}
def resilience(self) -> int: def resilience(self) -> int:
return min(self.read_resilience(), self.write_resilience()) return min(self.read_resilience(), self.write_resilience())
@ -75,6 +76,107 @@ class QuorumSystem(Generic[T]):
def write_resilience(self) -> int: def write_resilience(self) -> int:
return self.writes.resilience() return self.writes.resilience()
def dup_free(self) -> bool:
return self.reads.dup_free() and self.writes.dup_free()
def load(self,
optimize: str = LOAD,
load_limit: Optional[float] = None,
network_limit: Optional[float] = None,
latency_limit: Optional[datetime.timedelta] = None,
read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None,
f: int = 0) -> float:
return self.strategy(
optimize,
load_limit,
network_limit,
latency_limit,
read_fraction,
write_fraction,
f
).load(read_fraction, write_fraction)
def capacity(self,
optimize: str = LOAD,
load_limit: Optional[float] = None,
network_limit: Optional[float] = None,
latency_limit: Optional[datetime.timedelta] = None,
read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None,
f: int = 0) -> float:
return self.strategy(
optimize,
load_limit,
network_limit,
latency_limit,
read_fraction,
write_fraction,
f
).capacity(read_fraction, write_fraction)
def network_load(self,
optimize: str = LOAD,
load_limit: Optional[float] = None,
network_limit: Optional[float] = None,
latency_limit: Optional[datetime.timedelta] = None,
read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None,
f: int = 0) -> float:
return self.strategy(
optimize,
load_limit,
network_limit,
latency_limit,
read_fraction,
write_fraction,
f
).network_load(read_fraction, write_fraction)
def latency(self,
optimize: str = LOAD,
load_limit: Optional[float] = None,
network_limit: Optional[float] = None,
latency_limit: Optional[datetime.timedelta] = None,
read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None,
f: int = 0) -> float:
return self.strategy(
optimize,
load_limit,
network_limit,
latency_limit,
read_fraction,
write_fraction,
f
).latency(read_fraction, write_fraction)
def uniform_strategy(self, f: int = 0) -> 'Strategy[T]':
"""
uniform_strategy(f) returns a uniform strategy over the minimal
f-resilient quorums. That is, every minimal f-resilient quorum is
equally likely to be chosen.
"""
if f < 0:
raise ValueError('f must be >= 0')
elif f == 0:
read_quorums = list(self.read_quorums())
write_quorums = list(self.write_quorums())
else:
xs = list(self.elements())
read_quorums = list(self._f_resilient_quorums(f, xs, self.reads))
write_quorums = list(self._f_resilient_quorums(f, xs, self.reads))
if len(read_quorums) == 0:
raise ValueError(f'There are no {f}-resilient read quorums')
if len(write_quorums) == 0:
raise ValueError(f'There are no {f}-resilient write quorums')
read_quorums = self._minimize(read_quorums)
write_quorums = self._minimize(write_quorums)
sigma_r = {frozenset(rq): 1 / len(rq) for rq in read_quorums}
sigma_w = {frozenset(wq): 1 / len(wq) for wq in write_quorums}
return Strategy(self, sigma_r, sigma_w)
def strategy(self, def strategy(self,
optimize: str = LOAD, optimize: str = LOAD,
load_limit: Optional[float] = None, load_limit: Optional[float] = None,
@ -82,10 +184,10 @@ class QuorumSystem(Generic[T]):
latency_limit: Optional[datetime.timedelta] = None, latency_limit: Optional[datetime.timedelta] = None,
read_fraction: Optional[Distribution] = None, read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None, write_fraction: Optional[Distribution] = None,
f: int = 0) \ f: int = 0) -> 'Strategy[T]':
-> 'Strategy[T]': if optimize not in {LOAD, NETWORK, LATENCY}:
if f < 0: raise ValueError(
raise ValueError('f must be >= 0') f'optimize must be one of {LOAD}, {NETWORK}, or {LATENCY}')
if optimize == LOAD and load_limit is not None: if optimize == LOAD and load_limit is not None:
raise ValueError( raise ValueError(
@ -99,6 +201,9 @@ class QuorumSystem(Generic[T]):
raise ValueError( raise ValueError(
'a latency limit cannot be set when optimizing for latency') 'a latency limit cannot be set when optimizing for latency')
if f < 0:
raise ValueError('f must be >= 0')
d = distribution.canonicalize_rw(read_fraction, write_fraction) d = distribution.canonicalize_rw(read_fraction, write_fraction)
if f == 0: if f == 0:
return self._load_optimal_strategy( return self._load_optimal_strategy(
@ -110,7 +215,7 @@ class QuorumSystem(Generic[T]):
network_limit=network_limit, network_limit=network_limit,
latency_limit=latency_limit) latency_limit=latency_limit)
else: else:
xs = [node.x for node in self.nodes()] xs = list(self.elements())
read_quorums = list(self._f_resilient_quorums(f, xs, self.reads)) read_quorums = list(self._f_resilient_quorums(f, xs, self.reads))
write_quorums = list(self._f_resilient_quorums(f, xs, self.reads)) write_quorums = list(self._f_resilient_quorums(f, xs, self.reads))
if len(read_quorums) == 0: if len(read_quorums) == 0:
@ -126,13 +231,23 @@ class QuorumSystem(Generic[T]):
network_limit=network_limit, network_limit=network_limit,
latency_limit=latency_limit) latency_limit=latency_limit)
def dup_free(self) -> bool: def _minimize(self, sets: List[Set[T]]) -> List[Set[T]]:
return self.reads.dup_free() and self.writes.dup_free() sets = sorted(sets, key=lambda s: len(s))
minimal_elements: List[Set[T]] = []
for x in sets:
if not any(x >= y for y in minimal_elements):
minimal_elements.append(x)
return minimal_elements
def _f_resilient_quorums(self, def _f_resilient_quorums(self,
f: int, f: int,
xs: List[T], xs: List[T],
e: Expr) -> Iterator[Set[T]]: e: Expr) -> Iterator[Set[T]]:
"""
Consider a set X of elements in xs. We say X is f-resilient if, despite
removing an arbitrary set of f elements from X, X is a quorum in e.
_f_resilient_quorums returns the set of all f-resilient quorums.
"""
assert f >= 1 assert f >= 1
def helper(s: Set[T], i: int) -> Iterator[Set[T]]: def helper(s: Set[T], i: int) -> Iterator[Set[T]]:
@ -148,25 +263,6 @@ class QuorumSystem(Generic[T]):
return helper(set(), 0) return helper(set(), 0)
def load(self,
read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None,
f: int = 0) \
-> float:
return 0
# TODO(mwhittaker): Remove.
# sigma = self.strategy(read_fraction, write_fraction, f)
# return sigma.load(read_fraction, write_fraction)
def capacity(self,
read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None,
f: int = 0) \
-> float:
return 0
# TODO(mwhittaker): Remove.
# return 1 / self.load(read_fraction, write_fraction, f)
def _read_quorum_latency(self, quorum: Set[Node[T]]) -> datetime.timedelta: def _read_quorum_latency(self, quorum: Set[Node[T]]) -> datetime.timedelta:
return self._quorum_latency(quorum, self.is_read_quorum) return self._quorum_latency(quorum, self.is_read_quorum)
@ -205,54 +301,132 @@ class QuorumSystem(Generic[T]):
read_quorums = [{a, b}, {c, d}] read_quorums = [{a, b}, {c, d}]
write_quorums = [{a, c}, {a, d}, {b, c}, {b, d}] write_quorums = [{a, c}, {a, d}, {b, c}, {b, d}]
We can form a linear program to compute the optimal load of this quorum We want to find the strategy that is optimal with respect to load,
system for some fixed read fraction fr as follows. First, we create a network load, or latency that satisfies the provided load, network
variable ri for every read quorum i and a variable wi for every write load, or latency constraints.
quorum i. ri represents the probabilty of selecting the ith read
quorum, and wi represents the probabilty of selecting the ith write
quorum. We introduce an additional variable l that represents the load
and solve the following linear program.
min L subject to We can find the optimal strategy using linear programming. First, we
r0 + r1 + r2 = 1 create a variable ri for every read quorum i and a variable wi for
w0 + w1 = 1 every write quorum i. ri represents the probabilty of selecting the ith
fr (r0) + (1 - fr) (w0 + w1) <= L # a's load read quorum, and wi represents the probabilty of selecting the ith
fr (r0) + (1 - fr) (w2 + w3) <= L # b's load write quorum.
fr (r1) + (1 - fr) (w0 + w2) <= L # c's load
fr (r1) + (1 - fr) (w1 + w3) <= L # d's load
If we assume every element x has read capacity rcap_x and write We now explain how to represent load, network load, and latency as
capacity wcap_x, then we adjust the linear program like this. linear expressions.
min L subject to Load
r0 + r1 + r2 = 1 ====
w0 + w1 = 1 Assume a read fraction fr and write fraction fw. The load of a node a is
fr/rcap_a (r0) + (1 - fr)/wcap_a (w0 + w1) <= L # a's load
fr/rcap_b (r0) + (1 - fr)/wcap_b (w2 + w3) <= L # b's load
fr/rcap_c (r1) + (1 - fr)/wcap_c (w0 + w2) <= L # c's load
fr/rcap_d (r1) + (1 - fr)/wcap_d (w1 + w3) <= L # d's load
Assume we have fr = 0.9 with 80% probabilty and fr = 0.5 with 20%. Then load(a) = (fr * rprob(a) / rcap(a)) + (fw * wprob(a) / wcap(a))
we adjust the linear program as follows to find the strategy that
minimzes the average load.
min 0.8 * L_0.9 + 0.2 * L_0.5 subject to where prob_r(a) and prob_w(a) are the probabilities that a is selected
r0 + r1 + r2 = 1 as part of a read or write quorum respectively; and rcap(a) and wcap(a)
w0 + w1 = 1 are the read and write capacities of a. We can express prob_r(a) and
0.9/rcap_a (r0) + 0.1/wcap_a (w0 + w1) <= L_0.9 # a's load prob_w(a) as follows:
0.9/rcap_b (r0) + 0.1/wcap_b (w2 + w3) <= L_0.9 # b's load
0.9/rcap_c (r1) + 0.1/wcap_c (w0 + w2) <= L_0.9 # c's load rprob(a) = sum({ri | a is in read quorum i})
0.9/rcap_d (r1) + 0.1/wcap_d (w1 + w3) <= L_0.9 # d's load wprob(a) = sum({wi | a is in write quorum i})
0.5/rcap_a (r0) + 0.5/wcap_a (w0 + w1) <= L_0.5 # a's load
0.5/rcap_b (r0) + 0.5/wcap_b (w2 + w3) <= L_0.5 # b's load Using the example grid quorum above, we have:
0.5/rcap_c (r1) + 0.5/wcap_c (w0 + w2) <= L_0.5 # c's load
0.5/rcap_d (r1) + 0.5/wcap_d (w1 + w3) <= L_0.5 # d's load rprob(a) = r0 wprob(a) = w0 + w1
rprob(b) = r0 wprob(b) = w2 + w3
rprob(c) = r1 wprob(c) = w0 + w2
rprob(d) = r1 wprob(d) = w1 + w3
The load of a strategy is the maximum load on any node. We can compute
this by minimizing a new variable l and constraining the load of every
node to be less than l. Using the example above, we have
min l subject to
fr * r0 * rcap(a) + fw * (w0 + w1) * wcap(a) <= l
fr * r0 * rcap(b) + fw * (w2 + w3) * wcap(b) <= l
fr * r1 * rcap(c) + fw * (w0 + w2) * wcap(c) <= l
fr * r1 * rcap(d) + fw * (w1 + w3) * wcap(d) <= l
To compute the load of a strategy with respect to a distribution of
read_fractions, we compute the load for every value of fr and weight
according to the distribution. For example, imagine fr is 0.9 80% of
the time and 0.5 20% of the time. We have:
min 0.8 * l0.9 + 0.2 * l0.5
0.9 * r0 * rcap(a) + 0.1 * (w0 + w1) * wcap(a) <= l0.9
0.9 * r0 * rcap(b) + 0.1 * (w2 + w3) * wcap(b) <= l0.9
0.9 * r1 * rcap(c) + 0.1 * (w0 + w2) * wcap(c) <= l0.9
0.9 * r1 * rcap(d) + 0.1 * (w1 + w3) * wcap(d) <= l0.9
0.5 * r0 * rcap(a) + 0.5 * (w0 + w1) * wcap(a) <= l0.5
0.5 * r0 * rcap(b) + 0.5 * (w2 + w3) * wcap(b) <= l0.5
0.5 * r1 * rcap(c) + 0.5 * (w0 + w2) * wcap(c) <= l0.5
0.5 * r1 * rcap(d) + 0.5 * (w1 + w3) * wcap(d) <= l0.5
Let the expression for load be LOAD.
Network
=======
The network load of a strategy is the expected size of a quorum. For a
fixed fr, We can compute the network load as:
fr * sum_i(size(read quorum i) * ri) +
fw * sum_i(size(write quorum i) * ri)
Using the example above:
fr * (2*r0 + 2*r1) + fw * (2*w0 + 2*w1 + 2*w2 + 2*w3)
For a distribution of read fractions, we compute the weighted average.
Let the expression for network load be NETWORK.
Latency
=======
The latency of a strategy is the expected latency of a quorum. We can
compute the latency as:
fr * sum_i(latency(read quorum i) * ri) +
fw * sum_i(latency(write quorum i) * ri)
Using the example above (assuming every node has a latency of 1):
fr * (1*r0 + 1*r1) + fw * (1*w0 + 1*w1 + 1*w2 + 1*w3)
For a distribution of read fractions, we compute the weighted average.
Let the expression for latency be LATENCY.
Linear Program
==============
To find an optimal strategy, we use a linear program. The objective
specified by the user is minimized, and any provided constraints are
added as constraints to the program. For example, imagine the user
wants a load optimal strategy with network load <= 2 and latency <= 3.
We form the program:
min LOAD subject to
sum_i(ri) = 1 # ensure we have a valid distribution on read quorums
sum_i(wi) = 1 # ensure we have a valid distribution on write quorums
NETWORK <= 2
LATENCY <= 3
Using the example above assuming a fixed fr, we have:
min l subject to
fr * r0 * rcap(a) + fw * (w0 + w1) * wcap(a) <= l
fr * r0 * rcap(b) + fw * (w2 + w3) * wcap(b) <= l
fr * r1 * rcap(c) + fw * (w0 + w2) * wcap(c) <= l
fr * r1 * rcap(d) + fw * (w1 + w3) * wcap(d) <= l
fr * (2*r0 + 2*r1) + fw * (2*w0 + 2*w1 + 2*w2 + 2*w3) <= 2
fr * (1*r0 + 1*r1) + fw * (1*w0 + 1*w1 + 1*w2 + 1*w3) <= 3
If we instead wanted to minimize network load with load <= 4 and
latency <= 5, we would have the following program:
min fr * (2*r0 + 2*r1) +
fw * (2*w0 + 2*w1 + 2*w2 + 2*w3) subject to
fr * r0 * rcap(a) + fw * (w0 + w1) * wcap(a) <= 4
fr * r0 * rcap(b) + fw * (w2 + w3) * wcap(b) <= 4
fr * r1 * rcap(c) + fw * (w0 + w2) * wcap(c) <= 4
fr * r1 * rcap(d) + fw * (w1 + w3) * wcap(d) <= 4
fr * (1*r0 + 1*r1) + fw * (1*w0 + 1*w1 + 1*w2 + 1*w3) <= 5
""" """
nodes = self.nodes()
x_to_node = {node.x: node for node in nodes}
read_capacity = {node.x: node.read_capacity for node in nodes}
write_capacity = {node.x: node.write_capacity for node in nodes}
# Create a variable for every read quorum and every write quorum. While # Create a variable for every read quorum and every write quorum. While
# we do this, map each element x to the read and write quorums that # we do this, map each element x to the read and write quorums that
# it's in. For example, image we have the following read and write # it's in. For example, image we have the following read and write
@ -285,46 +459,46 @@ class QuorumSystem(Generic[T]):
for x in write_quorum: for x in write_quorum:
x_to_write_quorum_vars[x].append(v) x_to_write_quorum_vars[x].append(v)
fr = sum(weight * f for (f, weight) in read_fraction.items()) fr = sum(p * fr for (fr, p) in read_fraction.items())
def network() -> pulp.LpAffineExpression: def network() -> pulp.LpAffineExpression:
read_network = fr * sum( reads = fr * sum(
v * len(rq) v * len(rq)
for (rq, v) in zip(read_quorums, read_quorum_vars) for (rq, v) in zip(read_quorums, read_quorum_vars)
) )
write_network = (1 - fr) * sum( writes = (1 - fr) * sum(
v * len(wq) v * len(wq)
for (wq, v) in zip(write_quorums, write_quorum_vars) for (wq, v) in zip(write_quorums, write_quorum_vars)
) )
return read_network + write_network return reads + writes
def latency() -> pulp.LpAffineExpression: def latency() -> pulp.LpAffineExpression:
read_latency = fr * sum( reads = fr * sum(
v * self._read_quorum_latency(quorum).total_seconds() v * self._read_quorum_latency(quorum).total_seconds()
for (rq, v) in zip(read_quorums, read_quorum_vars) for (rq, v) in zip(read_quorums, read_quorum_vars)
for quorum in [{x_to_node[x] for x in rq}] for quorum in [{self.node(x) for x in rq}]
) )
write_latency = (1. - fr) * sum( writes = (1 - fr) * sum(
v * self._write_quorum_latency(quorum).total_seconds() v * self._write_quorum_latency(quorum).total_seconds()
for (wq, v) in zip(write_quorums, write_quorum_vars) for (wq, v) in zip(write_quorums, write_quorum_vars)
for quorum in [{x_to_node[x] for x in wq}] for quorum in [{self.node(x) for x in wq}]
) )
return read_latency + write_latency return reads + writes
def fr_load(problem: pulp.LpProblem, fr: float) -> pulp.LpAffineExpression: def fr_load(problem: pulp.LpProblem, fr: float) -> pulp.LpAffineExpression:
l = pulp.LpVariable(f'l_{fr}', 0, 1) l = pulp.LpVariable(f'l_{fr}', 0, 1)
for node in nodes: for node in self.nodes():
x = node.x x = node.x
x_load: pulp.LpAffineExpression = 0 x_load: pulp.LpAffineExpression = 0
if x in x_to_read_quorum_vars: if x in x_to_read_quorum_vars:
vs = x_to_read_quorum_vars[x] vs = x_to_read_quorum_vars[x]
x_load += fr * sum(vs) / read_capacity[x] x_load += fr * sum(vs) / self.node(x).read_capacity
if x in x_to_write_quorum_vars: if x in x_to_write_quorum_vars:
vs = x_to_write_quorum_vars[x] vs = x_to_write_quorum_vars[x]
x_load += (1 - fr) * sum(vs) / write_capacity[x] x_load += (1 - fr) * sum(vs) / self.node(x).write_capacity
problem += (x_load <= l, f'{x}{fr}') problem += (x_load <= l, f'{x}{fr}')
@ -332,11 +506,11 @@ class QuorumSystem(Generic[T]):
def load(problem: pulp.LpProblem, def load(problem: pulp.LpProblem,
read_fraction: Dict[float, float]) -> pulp.LpAffineExpression: read_fraction: Dict[float, float]) -> pulp.LpAffineExpression:
return sum(weight * fr_load(problem, fr) return sum(p * fr_load(problem, fr)
for (fr, weight) in read_fraction.items()) for (fr, p) in read_fraction.items())
# Form the linear program to find the load. # Form the linear program.
problem = pulp.LpProblem("load", pulp.LpMinimize) problem = pulp.LpProblem("optimal_strategy", pulp.LpMinimize)
# We add these constraints to make sure that the probabilities we # We add these constraints to make sure that the probabilities we
# select form valid probabilty distributions. # select form valid probabilty distributions.
@ -365,152 +539,147 @@ class QuorumSystem(Generic[T]):
'latency limit') 'latency limit')
# Solve the linear program. # Solve the linear program.
print(problem)
problem.solve(pulp.apis.PULP_CBC_CMD(msg=False)) problem.solve(pulp.apis.PULP_CBC_CMD(msg=False))
if problem.status != pulp.LpStatusOptimal: if problem.status != pulp.LpStatusOptimal:
raise ValueError('no strategy satisfies the given constraints') raise ValueError('no strategy satisfies the given constraints')
# Prune out any quorums with 0 probability. # Prune out any quorums with 0 probability.
non_zero_read_quorums = [ sigma_r = {
(rq, v.varValue) frozenset(rq): v.varValue
for (rq, v) in zip(read_quorums, read_quorum_vars) for (rq, v) in zip(read_quorums, read_quorum_vars)
if v.varValue != 0] if v.varValue != 0
non_zero_write_quorums = [ }
(wq, v.varValue) sigma_w = {
frozenset(wq): v.varValue
for (wq, v) in zip(write_quorums, write_quorum_vars) for (wq, v) in zip(write_quorums, write_quorum_vars)
if v.varValue != 0] if v.varValue != 0
return Strategy(self, }
[rq for (rq, _) in non_zero_read_quorums],
[weight for (_, weight) in non_zero_read_quorums], return Strategy(self, sigma_r, sigma_w)
[wq for (wq, _) in non_zero_write_quorums],
[weight for (_, weight) in non_zero_write_quorums])
class Strategy(Generic[T]): class Strategy(Generic[T]):
def __init__(self, def __init__(self,
qs: QuorumSystem[T], qs: QuorumSystem[T],
reads: List[Set[T]], sigma_r: Dict[FrozenSet[T], float],
read_weights: List[float], sigma_w: Dict[FrozenSet[T], float]) -> None:
writes: List[Set[T]],
write_weights: List[float]) -> None:
self.qs = qs self.qs = qs
self.reads = reads self.sigma_r = sigma_r
self.read_weights = read_weights self.sigma_w = sigma_w
self.writes = writes
self.write_weights = write_weights
self.unweighted_read_load: Dict[T, float] = \ # The probability that x is chosen as part of a read quorum.
collections.defaultdict(float) self.x_read_probability: Dict[T, float] = collections.defaultdict(float)
for (read_quorum, weight) in zip(self.reads, self.read_weights): for (read_quorum, p) in self.sigma_r.items():
for x in read_quorum: for x in read_quorum:
self.unweighted_read_load[x] += weight self.x_read_probability[x] += p
self.unweighted_write_load: Dict[T, float] = \ # The probability that x is chosen as part of a write quorum.
collections.defaultdict(float) self.x_write_probability: Dict[T, float] = collections.defaultdict(float)
for (write_quorum, weight) in zip(self.writes, self.write_weights): for (write_quorum, weight) in self.sigma_w.items():
for x in write_quorum: for x in write_quorum:
self.unweighted_write_load[x] += weight self.x_write_probability[x] += weight
@no_type_check
def __str__(self) -> str: def __str__(self) -> str:
non_zero_reads = {tuple(r): p # T may not comparable, so mypy complains about this sort.
for (r, p) in zip(self.reads, self.read_weights) reads = {tuple(sorted(rq)): p for (rq, p) in self.sigma_r.items()}
if p > 0} writes = {tuple(sorted(wq)): p for (wq, p) in self.sigma_w.items()}
non_zero_writes = {tuple(w): p return f'Strategy(reads={reads}, writes={writes})'
for (w, p) in zip(self.writes, self.write_weights)
if p > 0} def quorum_system(self) -> QuorumSystem[T]:
return f'Strategy(reads={non_zero_reads}, writes={non_zero_writes})' return self.qs
def node(self, x: T) -> Node[T]:
return self.qs.node(x)
def nodes(self) -> Set[Node[T]]:
return self.qs.nodes()
def get_read_quorum(self) -> Set[T]: def get_read_quorum(self) -> Set[T]:
return np.random.choice(self.reads, p=self.read_weights) return set(np.random.choice(list(self.sigma_r.keys()),
p=list(self.sigma_r.values())))
def get_write_quorum(self) -> Set[T]: def get_write_quorum(self) -> Set[T]:
return np.random.choice(self.writes, p=self.write_weights) return set(np.random.choice(list(self.sigma_w.keys()),
p=list(self.sigma_w.values())))
def load(self, def load(self,
read_fraction: Optional[Distribution] = None, read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None) \ write_fraction: Optional[Distribution] = None) -> float:
-> float:
d = distribution.canonicalize_rw(read_fraction, write_fraction) d = distribution.canonicalize_rw(read_fraction, write_fraction)
return sum(weight * self._load(fr) return sum(p * self._load(fr) for (fr, p) in d.items())
for (fr, weight) in d.items())
# TODO(mwhittaker): Rename throughput.
def capacity(self, def capacity(self,
read_fraction: Optional[Distribution] = None, read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None) \ write_fraction: Optional[Distribution] = None) -> float:
-> float:
return 1 / self.load(read_fraction, write_fraction) return 1 / self.load(read_fraction, write_fraction)
def network_load(self, def network_load(self,
read_fraction: Optional[Distribution] = None, read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None) -> float: write_fraction: Optional[Distribution] = None) -> float:
d = distribution.canonicalize_rw(read_fraction, write_fraction) d = distribution.canonicalize_rw(read_fraction, write_fraction)
fr = sum(weight * f for (f, weight) in d.items()) fr = sum(p * fr for (fr, p) in d.items())
read_network_load = fr * sum( reads = fr * sum(p * len(rq) for (rq, p) in self.sigma_r.items())
len(rq) * p writes = (1 - fr) * sum(p * len(wq) for (wq, p) in self.sigma_w.items())
for (rq, p) in zip(self.reads, self.read_weights) return reads + writes
)
write_network_load = (1 - fr) * sum(
len(wq) * p
for (wq, p) in zip(self.writes, self.write_weights)
)
return read_network_load + write_network_load
# mypy doesn't like calling sum with timedeltas.
@no_type_check
def latency(self, def latency(self,
read_fraction: Optional[Distribution] = None, read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None) \ write_fraction: Optional[Distribution] = None) \
-> datetime.timedelta: -> datetime.timedelta:
d = distribution.canonicalize_rw(read_fraction, write_fraction) d = distribution.canonicalize_rw(read_fraction, write_fraction)
fr = sum(weight * f for (f, weight) in d.items()) fr = sum(p * fr for (fr, p) in d.items())
read_latency = fr * sum(( reads = fr * sum((
self.qs._read_quorum_latency(quorum) * p # type: ignore p * self.qs._read_quorum_latency({self.node(x) for x in rq})
for (rq, p) in zip(self.reads, self.read_weights) for (rq, p) in self.sigma_r.items()
for quorum in [{self.qs.x_to_node[x] for x in rq}] ), datetime.timedelta(seconds=0))
), datetime.timedelta(seconds=0)) # type: ignore
write_latency = (1 - fr) * sum(( writes = (1 - fr) * sum((
self.qs._write_quorum_latency(quorum) * p # type: ignore p * self.qs._write_quorum_latency({self.node(x) for x in wq})
for (wq, p) in zip(self.writes, self.write_weights) for (wq, p) in self.sigma_w.items()
for quorum in [{self.qs.x_to_node[x] for x in wq}] ), datetime.timedelta(seconds=0))
), datetime.timedelta(seconds=0)) # type:ignore
return read_latency + write_latency # type: ignore return reads + writes
def node_load(self, def node_load(self,
node: Node[T], node: Node[T],
read_fraction: Optional[Distribution] = None, read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None) \ write_fraction: Optional[Distribution] = None) -> float:
-> float:
d = distribution.canonicalize_rw(read_fraction, write_fraction) d = distribution.canonicalize_rw(read_fraction, write_fraction)
return sum(weight * self._node_load(node.x, fr) return sum(p * self._node_load(node, fr) for (fr, p) in d.items())
for (fr, weight) in d.items())
def node_utilization(self, def node_utilization(self,
node: Node[T], node: Node[T],
read_fraction: Optional[Distribution] = None, read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None) \ write_fraction: Optional[Distribution] = None) \
-> float: -> float:
# TODO(mwhittaker): Implement. d = distribution.canonicalize_rw(read_fraction, write_fraction)
return 0.0 return sum(p * self._node_utilization(node, fr)
for (fr, p) in d.items())
def node_throghput(self, def node_throughput(self,
node: Node[T], node: Node[T],
read_fraction: Optional[Distribution] = None, read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None) \ write_fraction: Optional[Distribution] = None) -> float:
-> float: d = distribution.canonicalize_rw(read_fraction, write_fraction)
# TODO(mwhittaker): Implement. return sum(p * self._node_throughput(node, fr) for (fr, p) in d.items())
return 0.0
def _node_load(self, x: T, fr: float) -> float:
"""
_node_load returns the load on x given a fixed read fraction fr.
"""
fw = 1 - fr
node = self.qs.x_to_node[x]
return (fr * self.unweighted_read_load[x] / node.read_capacity +
fw * self.unweighted_write_load[x] / node.write_capacity)
def _load(self, fr: float) -> float: def _load(self, fr: float) -> float:
""" return max(self._node_load(node, fr) for node in self.nodes())
_load returns the load given a fixed read fraction fr.
""" def _node_load(self, node: Node[T], fr: float) -> float:
return max(self._node_load(node.x, fr) for node in self.qs.nodes()) fw = 1 - fr
return (fr * self.x_read_probability[node.x] / node.read_capacity +
fw * self.x_write_probability[node.x] / node.write_capacity)
def _node_utilization(self, node: Node[T], fr: float) -> float:
return self._node_load(node, fr) / self._load(fr)
def _node_throughput(self, node: Node[T], fr: float) -> float:
cap = 1 / self._load(fr)
fw = 1 - fr
return cap * (fr * self.x_read_probability[node.x] +
fw * self.x_write_probability[node.x])

View file

@ -4,7 +4,7 @@ from .distribution import Distribution
from .expr import Node from .expr import Node
from .geometry import Point, Segment from .geometry import Point, Segment
from .quorum_system import Strategy from .quorum_system import Strategy
from typing import Dict, List, Optional, Set, Tuple, TypeVar from typing import Dict, FrozenSet, List, Optional, Set, Tuple, TypeVar
import collections import collections
import matplotlib import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -107,19 +107,18 @@ def _plot_node_load_on(ax: plt.Axes,
x_index = {x: i for (i, x) in enumerate(x_list)} x_index = {x: i for (i, x) in enumerate(x_list)}
x_ticks = list(range(len(x_list))) x_ticks = list(range(len(x_list)))
def one_hot(quorum: Set[T]) -> np.array: def one_hot(quorum: FrozenSet[T]) -> np.array:
bar_heights = np.zeros(len(x_list)) bar_heights = np.zeros(len(x_list))
for x in quorum: for x in quorum:
bar_heights[x_index[x]] = 1 bar_heights[x_index[x]] = 1
return bar_heights return bar_heights
def plot_quorums(quorums: List[Set[T]], def plot_quorums(sigma: Dict[FrozenSet[T], float],
weights: List[float],
fraction: float, fraction: float,
bottoms: np.array, bottoms: np.array,
capacities: np.array, capacities: np.array,
cmap: matplotlib.colors.Colormap): cmap: matplotlib.colors.Colormap):
for (i, (quorum, weight)) in enumerate(zip(quorums, weights)): for (i, (quorum, weight)) in enumerate(sigma.items()):
bar_heights = scale * fraction * weight * one_hot(quorum) bar_heights = scale * fraction * weight * one_hot(quorum)
if scale_by_node_capacity: if scale_by_node_capacity:
bar_heights /= capacities bar_heights /= capacities
@ -127,7 +126,7 @@ def _plot_node_load_on(ax: plt.Axes,
ax.bar(x_ticks, ax.bar(x_ticks,
bar_heights, bar_heights,
bottom=bottoms, bottom=bottoms,
color=cmap(0.75 - i * 0.5 / len(quorums)), color=cmap(0.75 - i * 0.5 / len(sigma)),
edgecolor='white', width=0.8) edgecolor='white', width=0.8)
for j, (bar_height, bottom) in enumerate(zip(bar_heights, bottoms)): for j, (bar_height, bottom) in enumerate(zip(bar_heights, bottoms)):
@ -142,9 +141,9 @@ def _plot_node_load_on(ax: plt.Axes,
read_capacities = np.array([node.read_capacity for node in nodes]) read_capacities = np.array([node.read_capacity for node in nodes])
write_capacities = np.array([node.write_capacity for node in nodes]) write_capacities = np.array([node.write_capacity for node in nodes])
bottoms = np.zeros(len(x_list)) bottoms = np.zeros(len(x_list))
plot_quorums(sigma.reads, sigma.read_weights, fr, bottoms, read_capacities, plot_quorums(sigma.sigma_r, fr, bottoms, read_capacities,
matplotlib.cm.get_cmap('Reds')) matplotlib.cm.get_cmap('Reds'))
plot_quorums(sigma.writes, sigma.write_weights, fw, bottoms, plot_quorums(sigma.sigma_w, fw, bottoms,
write_capacities, matplotlib.cm.get_cmap('Blues')) write_capacities, matplotlib.cm.get_cmap('Blues'))
ax.set_xticks(x_ticks) ax.set_xticks(x_ticks)
ax.set_xticklabels(str(x) for x in x_list) ax.set_xticklabels(str(x) for x in x_list)

0
tests/__init__.py Normal file
View file

162
tests/test_geometry.py Normal file
View file

@ -0,0 +1,162 @@
from quorums import *
from quorums.geometry import *
from typing import Any, Callable, List, NamedTuple, Optional, Tuple
import unittest
class TestGeometry(unittest.TestCase):
def test_eq(self):
l = Point(0, 1)
r = Point(1, 1)
m = Point(0.5, 0.5)
self.assertEqual(Segment(l, r), Segment(l, r))
self.assertNotEqual(Segment(l, r), Segment(l, m))
def test_compatible(self):
s1 = Segment(Point(0, 1), Point(1, 2))
s2 = Segment(Point(0, 2), Point(1, 1))
s3 = Segment(Point(0.5, 2), Point(1, 1))
self.assertTrue(s1.compatible(s2))
self.assertTrue(s2.compatible(s1))
self.assertFalse(s1.compatible(s3))
self.assertFalse(s3.compatible(s1))
self.assertFalse(s2.compatible(s3))
self.assertFalse(s3.compatible(s2))
def test_call(self):
segment = Segment(Point(0, 0), Point(1, 1))
for x in [0.0, 0.25, 0.5, 0.75, 1.0]:
self.assertEqual(segment(x), x)
segment = Segment(Point(0, 0), Point(1, 2))
for x in [0.0, 0.25, 0.5, 0.75, 1.0]:
self.assertEqual(segment(x), 2*x)
segment = Segment(Point(1, 2), Point(3, 6))
for x in [1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]:
self.assertEqual(segment(x), 2*x)
segment = Segment(Point(0, 1), Point(1, 0))
for x in [0.0, 0.25, 0.5, 0.75, 1.0]:
self.assertEqual(segment(x), 1 - x)
def test_slope(self):
self.assertEqual(Segment(Point(0, 0), Point(1, 1)).slope(), 1.0)
self.assertEqual(Segment(Point(0, 1), Point(1, 2)).slope(), 1.0)
self.assertEqual(Segment(Point(1, 1), Point(2, 2)).slope(), 1.0)
self.assertEqual(Segment(Point(1, 1), Point(2, 3)).slope(), 2.0)
self.assertEqual(Segment(Point(1, 1), Point(2, 0)).slope(), -1.0)
def test_above(self):
s1 = Segment(Point(0, 0), Point(1, 0.5))
s2 = Segment(Point(0, 0.5), Point(1, 2))
s3 = Segment(Point(0, 1.5), Point(1, 0.5))
self.assertFalse(s1.above(s1))
self.assertFalse(s1.above(s2))
self.assertFalse(s1.above(s3))
self.assertTrue(s2.above(s1))
self.assertFalse(s2.above(s2))
self.assertFalse(s2.above(s3))
self.assertTrue(s3.above(s1))
self.assertFalse(s3.above(s2))
self.assertFalse(s3.above(s3))
def test_above_eq(self):
s1 = Segment(Point(0, 0), Point(1, 0.5))
s2 = Segment(Point(0, 0.5), Point(1, 2))
s3 = Segment(Point(0, 1.5), Point(1, 0.5))
self.assertTrue(s1.above_eq(s1))
self.assertFalse(s1.above_eq(s2))
self.assertFalse(s1.above_eq(s3))
self.assertTrue(s2.above_eq(s1))
self.assertTrue(s2.above_eq(s2))
self.assertFalse(s2.above_eq(s3))
self.assertTrue(s3.above_eq(s1))
self.assertFalse(s3.above_eq(s2))
self.assertTrue(s3.above_eq(s3))
def test_intersects(self):
s1 = Segment(Point(0, 0), Point(1, 0.5))
s2 = Segment(Point(0, 0.5), Point(1, 2))
s3 = Segment(Point(0, 1.5), Point(1, 0.5))
self.assertTrue(s1.intersects(s1))
self.assertFalse(s1.intersects(s2))
self.assertTrue(s1.intersects(s3))
self.assertFalse(s2.intersects(s1))
self.assertTrue(s2.intersects(s2))
self.assertTrue(s2.intersects(s3))
self.assertTrue(s3.intersects(s1))
self.assertTrue(s3.intersects(s2))
self.assertTrue(s3.intersects(s3))
def test_intersection(self):
s1 = Segment(Point(0, 0), Point(1, 1))
s2 = Segment(Point(0, 1), Point(1, 0))
s3 = Segment(Point(0, 1), Point(1, 1))
s4 = Segment(Point(0, 0.25), Point(1, 0.25))
self.assertEqual(s1.intersection(s1), None)
self.assertEqual(s1.intersection(s2), Point(0.5, 0.5))
self.assertEqual(s1.intersection(s3), Point(1, 1))
self.assertEqual(s1.intersection(s4), Point(0.25, 0.25))
self.assertEqual(s2.intersection(s1), Point(0.5, 0.5))
self.assertEqual(s2.intersection(s2), None)
self.assertEqual(s2.intersection(s3), Point(0, 1))
self.assertEqual(s2.intersection(s4), Point(0.75, 0.25))
self.assertEqual(s3.intersection(s1), Point(1, 1))
self.assertEqual(s3.intersection(s2), Point(0, 1))
self.assertEqual(s3.intersection(s3), None)
self.assertEqual(s3.intersection(s4), None)
self.assertEqual(s4.intersection(s1), Point(0.25, 0.25))
self.assertEqual(s4.intersection(s2), Point(0.75, 0.25))
self.assertEqual(s4.intersection(s3), None)
self.assertEqual(s4.intersection(s4), None)
def test_max_one_segment(self):
s1 = Segment(Point(0, 0), Point(1, 1))
s2 = Segment(Point(0, 1), Point(1, 0))
s3 = Segment(Point(0, 1), Point(1, 1))
s4 = Segment(Point(0, 0.25), Point(1, 0.25))
s5 = Segment(Point(0, 0.75), Point(1, 0.75))
def is_subset(xs: List[Any], ys: List[Any]) -> bool:
return all(x in ys for x in xs)
for s in [s1, s2, s3, s4, s5]:
self.assertEqual(max_of_segments([s]), [s.l, s.r])
expected = [
([s1, s1], [(0, 0), (1, 1)]),
([s1, s2], [(0, 1), (0.5, 0.5), (1, 1)]),
([s1, s3], [(0, 1), (1, 1)]),
([s1, s4], [(0, 0.25), (0.25, 0.25), (1, 1)]),
([s1, s5], [(0, 0.75), (0.75, 0.75), (1, 1)]),
([s2, s2], [(0, 1), (1, 0)]),
([s2, s3], [(0, 1), (1, 1)]),
([s2, s4], [(0, 1), (0.75, 0.25), (1, 0.25)]),
([s2, s5], [(0, 1), (0.25, 0.75), (1, 0.75)]),
([s3, s3], [(0, 1), (1, 1)]),
([s3, s4], [(0, 1), (1, 1)]),
([s3, s5], [(0, 1), (1, 1)]),
([s4, s4], [(0, 0.25), (1, 0.25)]),
([s4, s5], [(0, 0.75), (1, 0.75)]),
([s5, s5], [(0, 0.75), (1, 0.75)]),
([s1, s2, s4], [(0, 1), (0.5, 0.5), (1, 1)]),
([s1, s2, s5], [(0, 1), (0.25, 0.75), (0.75, 0.75), (1, 1)]),
]
for segments, path in expected:
self.assertTrue(is_subset(path, max_of_segments(segments)))
self.assertTrue(is_subset(path, max_of_segments(segments[::-1])))