From 7f980997efa9d9f3d5e7773a5333307f2db76020 Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Sat, 3 Apr 2021 10:44:33 -0700 Subject: [PATCH] Fixes for type check errors exposed by pyright Use SupportsLessThan to indicate that TypeVar T should be comparable. --- examples/plot_workload_distribution.py | 4 +-- quoracle/expr.py | 14 ++++++++--- quoracle/quorum_system.py | 34 ++++++++++++++------------ quoracle/search.py | 7 ++---- quoracle/viz.py | 5 +--- tests/test_expr.py | 6 ++--- 6 files changed, 37 insertions(+), 33 deletions(-) diff --git a/examples/plot_workload_distribution.py b/examples/plot_workload_distribution.py index 71f4f95..11510c1 100644 --- a/examples/plot_workload_distribution.py +++ b/examples/plot_workload_distribution.py @@ -39,12 +39,12 @@ def main(output_filename: str) -> None: for fr in dist.keys(): sigma = qs.strategy(read_fraction=fr) ys = [sigma.capacity(read_fraction=x) for x in xs] - ax.plot(xs, ys, '--', label=str(f'$\sigma_{{{fr}}}$'), linewidth=1, + ax.plot(xs, ys, '--', label=str(f'$\\sigma_{{{fr}}}$'), linewidth=1, marker=next(markers), markevery=25, markersize=4, alpha=0.75) sigma = qs.strategy(read_fraction=dist) ys = [sigma.capacity(read_fraction=x) for x in xs] - ax.plot(xs, ys, label='$\sigma$', linewidth=1.5, + ax.plot(xs, ys, label='$\\sigma$', linewidth=1.5, marker=next(markers), markevery=25, markersize=4) ax.legend(ncol=3, loc='lower center', bbox_to_anchor=(0.5, 1.0)) diff --git a/quoracle/expr.py b/quoracle/expr.py index f2c11f7..0890f62 100644 --- a/quoracle/expr.py +++ b/quoracle/expr.py @@ -1,14 +1,17 @@ -from typing import Dict, Iterator, Generic, List, Optional, Set, TypeVar +from typing import Any, Dict, Iterator, Generic, List, Optional, Protocol, Set, TypeVar import datetime import itertools import pulp -T = TypeVar('T') +class SupportsLessThan(Protocol): + def __lt__(self, __other: Any) -> bool: ... + +T = TypeVar("T", bound=SupportsLessThan) -def _min_hitting_set(sets: Iterator[Set[T]]) -> int: - x_vars: Dict[T, pulp.LpVariable] = dict() +def _min_hitting_set(sets: Iterator[Set]) -> int: + x_vars: Dict[Any, pulp.LpVariable] = dict() next_id = itertools.count() problem = pulp.LpProblem("min_hitting_set", pulp.LpMinimize) @@ -131,6 +134,9 @@ class Node(Expr[T]): def __repr__(self) -> str: return f'Node({self.x})' + def __lt__(self, other) -> bool: + return self.x < other.x + def quorums(self) -> Iterator[Set[T]]: yield {self.x} diff --git a/quoracle/quorum_system.py b/quoracle/quorum_system.py index 0bdfa92..90bb1f7 100644 --- a/quoracle/quorum_system.py +++ b/quoracle/quorum_system.py @@ -1,9 +1,20 @@ from . import distribution from . import geometry from .distribution import Distribution -from .expr import Expr, Node +from .expr import Expr, Node, T from .geometry import Point, Segment -from typing import * +from typing import ( + Any, + Dict, + Iterator, + Generic, + List, + Optional, + Set, + FrozenSet, + Callable, + Tuple, +) import collections import datetime import itertools @@ -11,9 +22,6 @@ import math import pulp import random - -T = TypeVar('T') - LOAD = 'load' NETWORK = 'network' LATENCY = 'latency' @@ -143,7 +151,7 @@ class QuorumSystem(Generic[T]): latency_limit: Optional[datetime.timedelta] = None, read_fraction: Optional[Distribution] = None, write_fraction: Optional[Distribution] = None, - f: int = 0) -> float: + f: int = 0) -> datetime.timedelta: return self.strategy( optimize=optimize, load_limit=load_limit, @@ -606,9 +614,7 @@ class Strategy(Generic[T]): for x in write_quorum: self.x_write_probability[x] += weight - @no_type_check def __str__(self) -> str: - # T may not comparable, so mypy complains about this sort. reads = {tuple(sorted(rq)): p for (rq, p) in self.sigma_r.items()} writes = {tuple(sorted(wq)): p for (wq, p) in self.sigma_w.items()} return f'Strategy(reads={reads}, writes={writes})' @@ -651,8 +657,6 @@ class Strategy(Generic[T]): writes = (1 - fr) * sum(p * len(wq) for (wq, p) in self.sigma_w.items()) return reads + writes - # mypy doesn't like calling sum with timedeltas. - @no_type_check def latency(self, read_fraction: Optional[Distribution] = None, write_fraction: Optional[Distribution] = None) \ @@ -661,16 +665,16 @@ class Strategy(Generic[T]): fr = sum(p * fr for (fr, p) in d.items()) reads = fr * sum(( - p * self.qs._read_quorum_latency({self.node(x) for x in rq}) + p * self.qs._read_quorum_latency({self.node(x) for x in rq}) # type: ignore[misc] for (rq, p) in self.sigma_r.items() - ), datetime.timedelta(seconds=0)) + ), datetime.timedelta(seconds=0)) # type: ignore[arg-type] writes = (1 - fr) * sum(( - p * self.qs._write_quorum_latency({self.node(x) for x in wq}) + p * self.qs._write_quorum_latency({self.node(x) for x in wq}) # type: ignore[misc] for (wq, p) in self.sigma_w.items() - ), datetime.timedelta(seconds=0)) + ), datetime.timedelta(seconds=0)) # type: ignore[arg-type] - return reads + writes + return reads + writes # type: ignore[return-value] def node_load(self, node: Node[T], diff --git a/quoracle/search.py b/quoracle/search.py index 16be3f0..bae63cd 100644 --- a/quoracle/search.py +++ b/quoracle/search.py @@ -1,15 +1,12 @@ from .distribution import Distribution -from .expr import choose, Expr, Node +from .expr import choose, Expr, Node, T from .quorum_system import (LATENCY, LOAD, NETWORK, NoStrategyFoundError, QuorumSystem, Strategy, Tuple) -from typing import Iterator, List, Optional, TypeVar +from typing import Iterator, List, Optional import datetime import itertools -T = TypeVar('T') - - class NoQuorumSystemFoundError(ValueError): pass diff --git a/quoracle/viz.py b/quoracle/viz.py index 5fdc0b5..7159c75 100644 --- a/quoracle/viz.py +++ b/quoracle/viz.py @@ -1,7 +1,7 @@ from . import distribution from . import geometry from .distribution import Distribution -from .expr import Node +from .expr import Node, T from .geometry import Point, Segment from .quorum_system import Strategy from typing import Dict, FrozenSet, List, Optional, Set, Tuple, TypeVar @@ -10,9 +10,6 @@ import matplotlib import matplotlib.pyplot as plt -T = TypeVar('T') - - def plot_node_load(filename: str, strategy: Strategy[T], nodes: Optional[List[Node[T]]] = None, diff --git a/tests/test_expr.py b/tests/test_expr.py index 21b6840..a5fb441 100644 --- a/tests/test_expr.py +++ b/tests/test_expr.py @@ -1,8 +1,8 @@ -from quoracle import * -from quoracle.expr import * -from typing import Any, FrozenSet import unittest +from quoracle.expr import Expr, Node, choose +from typing import Any, List, Set, FrozenSet + class TestExpr(unittest.TestCase): def test_quorums(self): def assert_equal(e: Expr[str], xs: List[Set[str]]) -> None: