diff --git a/quorums/geometry.py b/quorums/geometry.py index 52225a0..41eef4e 100644 --- a/quorums/geometry.py +++ b/quorums/geometry.py @@ -1,4 +1,5 @@ from typing import Any, Callable, List, NamedTuple, Optional, Tuple +import math import unittest @@ -26,13 +27,20 @@ class Segment: else: return False - def compatible(self, other: 'Segment') -> float: - return self.l.x == other.l.x and self.r.x == other.r.x + def __hash__(self) -> int: + return hash((self.l, self.r)) def __call__(self, x: float) -> float: assert self.l.x <= x <= self.r.x return self.slope() * (x - self.l.x) + self.l.y + def approximately_equal(self, other: 'Segment') -> float: + return (math.isclose(self.l.y, other.l.y, rel_tol=1e-5) and + math.isclose(self.r.y, other.r.y, rel_tol=1e-5)) + + def compatible(self, other: 'Segment') -> float: + return self.l.x == other.l.x and self.r.x == other.r.x + def slope(self) -> float: return (self.r.y - self.l.y) / (self.r.x - self.l.x) @@ -72,36 +80,16 @@ def max_of_segments(segments: List[Segment]) -> List[Tuple[float, float]]: assert len({segment.l.x for segment in segments}) == 1 assert len({segment.r.x for segment in segments}) == 1 - # First, we remove any segments that are subsumed by other segments. - non_dominated: List[Segment] = [] - for segment in segments: - if any(other.above_eq(segment) for other in non_dominated): - # If this segment is dominated by another, we exclude it. - pass - else: - # Otherwise, we add this segment and remove any that it dominates. - non_dominated = [other - for other in non_dominated - if not segment.above_eq(other)] - non_dominated.append(segment) - - # Next, we start at the leftmost segment and continually jump over to the - # segment with the first intersection. - segment = max(non_dominated, key=lambda segment: segment.l.y) - path: List[Point] = [segment.l] - while True: - intersections: List[Tuple[Point, Segment]] = [] - for other in non_dominated: - p = segment.intersection(other) - if p is not None and p.x > path[-1].x: - intersections.append((p, other)) - - if len(intersections) == 0: - path.append(segment.r) - return [(p.x, p.y) for p in path] - - intersection, segment = min(intersections, key=lambda t: t[0].x) - path.append(intersection) + # We compute the x-coordinate of every intersection point. We sort the + # x-coordinates and for every x, we compute the highest line at that point. + xs: List[float] = [0.0, 1.0] + for (i, s1) in enumerate(segments): + for (j, s2) in enumerate(segments[i + 1:], i + 1): + p = s1.intersection(s2) + if p is not None: + xs.append(p.x) + xs.sort() + return [(x, max(segments, key=lambda s: s(x))(x)) for x in xs] class TestGeometry(unittest.TestCase): @@ -231,6 +219,9 @@ class TestGeometry(unittest.TestCase): s4 = Segment(Point(0, 0.25), Point(1, 0.25)) s5 = Segment(Point(0, 0.75), Point(1, 0.75)) + def is_subset(xs: List[Any], ys: List[Any]) -> bool: + return all(x in ys for x in xs) + for s in [s1, s2, s3, s4, s5]: self.assertEqual(max_of_segments([s]), [s.l, s.r]) @@ -255,8 +246,8 @@ class TestGeometry(unittest.TestCase): ([s1, s2, s5], [(0, 1), (0.25, 0.75), (0.75, 0.75), (1, 1)]), ] for segments, path in expected: - self.assertEqual(max_of_segments(segments), path, segments) - self.assertEqual(max_of_segments(segments[::-1]), path, segments) + self.assertTrue(is_subset(path, max_of_segments(segments))) + self.assertTrue(is_subset(path, max_of_segments(segments[::-1]))) if __name__ == '__main__': diff --git a/quorums/strategy.py b/quorums/strategy.py index 33a659d..0816e55 100644 --- a/quorums/strategy.py +++ b/quorums/strategy.py @@ -1,8 +1,12 @@ from . import distribution +from . import geometry from .distribution import Distribution from .expr import Node -from typing import Dict, Generic, List, Optional, Set, TypeVar +from .geometry import Point, Segment +from typing import Dict, Generic, List, Optional, Set, Tuple, TypeVar import collections +import itertools +import math import matplotlib import matplotlib.pyplot as plt import numpy as np @@ -69,12 +73,12 @@ class Strategy(Generic[T]): for (fr, weight) in d.items()) def node_load(self, - x: T, + node: Node[T], read_fraction: Optional[Distribution] = None, write_fraction: Optional[Distribution] = None) \ -> float: d = distribution.canonicalize_rw(read_fraction, write_fraction) - return sum(weight * self._node_load(x, fr) + return sum(weight * self._node_load(node.x, fr) for (fr, weight) in d.items()) def capacity(self, @@ -164,6 +168,65 @@ class Strategy(Generic[T]): read_fraction=read_fraction, write_fraction=write_fraction) + def plot_load_distribution(self, + filename: str, + nodes: Optional[List[Node[T]]] = None) \ + -> None: + fig, ax = plt.subplots() + self.plot_load_distribution_on(ax, nodes) + ax.set_xlabel('Read Fraction') + ax.set_ylabel('Load') + fig.tight_layout() + fig.savefig(filename) + + def _group(self, segments: List[Tuple[Segment, T]]) -> Dict[Segment, List[T]]: + groups: Dict[Segment, List[T]] = collections.defaultdict(list) + for segment, x in segments: + match_found = False + for other, xs in groups.items(): + if segment.approximately_equal(other): + xs.append(x) + match_found = True + break + + if not match_found: + groups[segment].append(x) + + return groups + + def plot_load_distribution_on(self, + ax: plt.Axes, + nodes: Optional[List[Node[T]]] = None) \ + -> None: + nodes = nodes or list(self.nodes) + + # We want to plot every node's load distribution. Multiple nodes might + # have the same load distribution, so we group the nodes by their + # distribution. The grouping is a little annoying because two floats + # might not be exactly equal but pretty close. + groups = self._group([ + (Segment(Point(0, self.node_load(node, read_fraction=0)), + Point(1, self.node_load(node, read_fraction=1))), node.x) + for node in nodes + ]) + + # Compute and plot the max of all segments. We increase the line + # slightly so it doesn't overlap with the other lines. + path = geometry.max_of_segments(list(groups.keys())) + ax.plot([p[0] for p in path], + [p[1] for p in path], + label='load', + linewidth=4) + + for segment, xs in groups.items(): + ax.plot([segment.l.x, segment.r.x], + [segment.l.y, segment.r.y], + '--', + label=','.join(str(x) for x in xs), + linewidth=2, + alpha=0.75) + + def _node_load(self, x: T, fr: float) -> float: """ _node_load returns the load on x given a fixed read fraction fr.