Fixes for type check errors exposed by pyright

Use SupportsLessThan to indicate that TypeVar T should be
comparable.
This commit is contained in:
Arun Sharma 2021-04-03 10:44:33 -07:00
parent f07680fb52
commit 7f980997ef
6 changed files with 37 additions and 33 deletions

View file

@ -39,12 +39,12 @@ def main(output_filename: str) -> None:
for fr in dist.keys():
sigma = qs.strategy(read_fraction=fr)
ys = [sigma.capacity(read_fraction=x) for x in xs]
ax.plot(xs, ys, '--', label=str(f'$\sigma_{{{fr}}}$'), linewidth=1,
ax.plot(xs, ys, '--', label=str(f'$\\sigma_{{{fr}}}$'), linewidth=1,
marker=next(markers), markevery=25, markersize=4, alpha=0.75)
sigma = qs.strategy(read_fraction=dist)
ys = [sigma.capacity(read_fraction=x) for x in xs]
ax.plot(xs, ys, label='$\sigma$', linewidth=1.5,
ax.plot(xs, ys, label='$\\sigma$', linewidth=1.5,
marker=next(markers), markevery=25, markersize=4)
ax.legend(ncol=3, loc='lower center', bbox_to_anchor=(0.5, 1.0))

View file

@ -1,14 +1,17 @@
from typing import Dict, Iterator, Generic, List, Optional, Set, TypeVar
from typing import Any, Dict, Iterator, Generic, List, Optional, Protocol, Set, TypeVar
import datetime
import itertools
import pulp
T = TypeVar('T')
class SupportsLessThan(Protocol):
def __lt__(self, __other: Any) -> bool: ...
T = TypeVar("T", bound=SupportsLessThan)
def _min_hitting_set(sets: Iterator[Set[T]]) -> int:
x_vars: Dict[T, pulp.LpVariable] = dict()
def _min_hitting_set(sets: Iterator[Set]) -> int:
x_vars: Dict[Any, pulp.LpVariable] = dict()
next_id = itertools.count()
problem = pulp.LpProblem("min_hitting_set", pulp.LpMinimize)
@ -131,6 +134,9 @@ class Node(Expr[T]):
def __repr__(self) -> str:
return f'Node({self.x})'
def __lt__(self, other) -> bool:
return self.x < other.x
def quorums(self) -> Iterator[Set[T]]:
yield {self.x}

View file

@ -1,9 +1,20 @@
from . import distribution
from . import geometry
from .distribution import Distribution
from .expr import Expr, Node
from .expr import Expr, Node, T
from .geometry import Point, Segment
from typing import *
from typing import (
Any,
Dict,
Iterator,
Generic,
List,
Optional,
Set,
FrozenSet,
Callable,
Tuple,
)
import collections
import datetime
import itertools
@ -11,9 +22,6 @@ import math
import pulp
import random
T = TypeVar('T')
LOAD = 'load'
NETWORK = 'network'
LATENCY = 'latency'
@ -143,7 +151,7 @@ class QuorumSystem(Generic[T]):
latency_limit: Optional[datetime.timedelta] = None,
read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None,
f: int = 0) -> float:
f: int = 0) -> datetime.timedelta:
return self.strategy(
optimize=optimize,
load_limit=load_limit,
@ -606,9 +614,7 @@ class Strategy(Generic[T]):
for x in write_quorum:
self.x_write_probability[x] += weight
@no_type_check
def __str__(self) -> str:
# T may not comparable, so mypy complains about this sort.
reads = {tuple(sorted(rq)): p for (rq, p) in self.sigma_r.items()}
writes = {tuple(sorted(wq)): p for (wq, p) in self.sigma_w.items()}
return f'Strategy(reads={reads}, writes={writes})'
@ -651,8 +657,6 @@ class Strategy(Generic[T]):
writes = (1 - fr) * sum(p * len(wq) for (wq, p) in self.sigma_w.items())
return reads + writes
# mypy doesn't like calling sum with timedeltas.
@no_type_check
def latency(self,
read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None) \
@ -661,16 +665,16 @@ class Strategy(Generic[T]):
fr = sum(p * fr for (fr, p) in d.items())
reads = fr * sum((
p * self.qs._read_quorum_latency({self.node(x) for x in rq})
p * self.qs._read_quorum_latency({self.node(x) for x in rq}) # type: ignore[misc]
for (rq, p) in self.sigma_r.items()
), datetime.timedelta(seconds=0))
), datetime.timedelta(seconds=0)) # type: ignore[arg-type]
writes = (1 - fr) * sum((
p * self.qs._write_quorum_latency({self.node(x) for x in wq})
p * self.qs._write_quorum_latency({self.node(x) for x in wq}) # type: ignore[misc]
for (wq, p) in self.sigma_w.items()
), datetime.timedelta(seconds=0))
), datetime.timedelta(seconds=0)) # type: ignore[arg-type]
return reads + writes
return reads + writes # type: ignore[return-value]
def node_load(self,
node: Node[T],

View file

@ -1,15 +1,12 @@
from .distribution import Distribution
from .expr import choose, Expr, Node
from .expr import choose, Expr, Node, T
from .quorum_system import (LATENCY, LOAD, NETWORK, NoStrategyFoundError,
QuorumSystem, Strategy, Tuple)
from typing import Iterator, List, Optional, TypeVar
from typing import Iterator, List, Optional
import datetime
import itertools
T = TypeVar('T')
class NoQuorumSystemFoundError(ValueError):
pass

View file

@ -1,7 +1,7 @@
from . import distribution
from . import geometry
from .distribution import Distribution
from .expr import Node
from .expr import Node, T
from .geometry import Point, Segment
from .quorum_system import Strategy
from typing import Dict, FrozenSet, List, Optional, Set, Tuple, TypeVar
@ -10,9 +10,6 @@ import matplotlib
import matplotlib.pyplot as plt
T = TypeVar('T')
def plot_node_load(filename: str,
strategy: Strategy[T],
nodes: Optional[List[Node[T]]] = None,

View file

@ -1,8 +1,8 @@
from quoracle import *
from quoracle.expr import *
from typing import Any, FrozenSet
import unittest
from quoracle.expr import Expr, Node, choose
from typing import Any, List, Set, FrozenSet
class TestExpr(unittest.TestCase):
def test_quorums(self):
def assert_equal(e: Expr[str], xs: List[Set[str]]) -> None: