Fixes for type check errors exposed by pyright

Use SupportsLessThan to indicate that TypeVar T should be
comparable.
This commit is contained in:
Arun Sharma 2021-04-03 10:44:33 -07:00
parent f07680fb52
commit 7f980997ef
6 changed files with 37 additions and 33 deletions

View file

@ -39,12 +39,12 @@ def main(output_filename: str) -> None:
for fr in dist.keys(): for fr in dist.keys():
sigma = qs.strategy(read_fraction=fr) sigma = qs.strategy(read_fraction=fr)
ys = [sigma.capacity(read_fraction=x) for x in xs] ys = [sigma.capacity(read_fraction=x) for x in xs]
ax.plot(xs, ys, '--', label=str(f'$\sigma_{{{fr}}}$'), linewidth=1, ax.plot(xs, ys, '--', label=str(f'$\\sigma_{{{fr}}}$'), linewidth=1,
marker=next(markers), markevery=25, markersize=4, alpha=0.75) marker=next(markers), markevery=25, markersize=4, alpha=0.75)
sigma = qs.strategy(read_fraction=dist) sigma = qs.strategy(read_fraction=dist)
ys = [sigma.capacity(read_fraction=x) for x in xs] ys = [sigma.capacity(read_fraction=x) for x in xs]
ax.plot(xs, ys, label='$\sigma$', linewidth=1.5, ax.plot(xs, ys, label='$\\sigma$', linewidth=1.5,
marker=next(markers), markevery=25, markersize=4) marker=next(markers), markevery=25, markersize=4)
ax.legend(ncol=3, loc='lower center', bbox_to_anchor=(0.5, 1.0)) ax.legend(ncol=3, loc='lower center', bbox_to_anchor=(0.5, 1.0))

View file

@ -1,14 +1,17 @@
from typing import Dict, Iterator, Generic, List, Optional, Set, TypeVar from typing import Any, Dict, Iterator, Generic, List, Optional, Protocol, Set, TypeVar
import datetime import datetime
import itertools import itertools
import pulp import pulp
T = TypeVar('T') class SupportsLessThan(Protocol):
def __lt__(self, __other: Any) -> bool: ...
T = TypeVar("T", bound=SupportsLessThan)
def _min_hitting_set(sets: Iterator[Set[T]]) -> int: def _min_hitting_set(sets: Iterator[Set]) -> int:
x_vars: Dict[T, pulp.LpVariable] = dict() x_vars: Dict[Any, pulp.LpVariable] = dict()
next_id = itertools.count() next_id = itertools.count()
problem = pulp.LpProblem("min_hitting_set", pulp.LpMinimize) problem = pulp.LpProblem("min_hitting_set", pulp.LpMinimize)
@ -131,6 +134,9 @@ class Node(Expr[T]):
def __repr__(self) -> str: def __repr__(self) -> str:
return f'Node({self.x})' return f'Node({self.x})'
def __lt__(self, other) -> bool:
return self.x < other.x
def quorums(self) -> Iterator[Set[T]]: def quorums(self) -> Iterator[Set[T]]:
yield {self.x} yield {self.x}

View file

@ -1,9 +1,20 @@
from . import distribution from . import distribution
from . import geometry from . import geometry
from .distribution import Distribution from .distribution import Distribution
from .expr import Expr, Node from .expr import Expr, Node, T
from .geometry import Point, Segment from .geometry import Point, Segment
from typing import * from typing import (
Any,
Dict,
Iterator,
Generic,
List,
Optional,
Set,
FrozenSet,
Callable,
Tuple,
)
import collections import collections
import datetime import datetime
import itertools import itertools
@ -11,9 +22,6 @@ import math
import pulp import pulp
import random import random
T = TypeVar('T')
LOAD = 'load' LOAD = 'load'
NETWORK = 'network' NETWORK = 'network'
LATENCY = 'latency' LATENCY = 'latency'
@ -143,7 +151,7 @@ class QuorumSystem(Generic[T]):
latency_limit: Optional[datetime.timedelta] = None, latency_limit: Optional[datetime.timedelta] = None,
read_fraction: Optional[Distribution] = None, read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None, write_fraction: Optional[Distribution] = None,
f: int = 0) -> float: f: int = 0) -> datetime.timedelta:
return self.strategy( return self.strategy(
optimize=optimize, optimize=optimize,
load_limit=load_limit, load_limit=load_limit,
@ -606,9 +614,7 @@ class Strategy(Generic[T]):
for x in write_quorum: for x in write_quorum:
self.x_write_probability[x] += weight self.x_write_probability[x] += weight
@no_type_check
def __str__(self) -> str: def __str__(self) -> str:
# T may not comparable, so mypy complains about this sort.
reads = {tuple(sorted(rq)): p for (rq, p) in self.sigma_r.items()} reads = {tuple(sorted(rq)): p for (rq, p) in self.sigma_r.items()}
writes = {tuple(sorted(wq)): p for (wq, p) in self.sigma_w.items()} writes = {tuple(sorted(wq)): p for (wq, p) in self.sigma_w.items()}
return f'Strategy(reads={reads}, writes={writes})' return f'Strategy(reads={reads}, writes={writes})'
@ -651,8 +657,6 @@ class Strategy(Generic[T]):
writes = (1 - fr) * sum(p * len(wq) for (wq, p) in self.sigma_w.items()) writes = (1 - fr) * sum(p * len(wq) for (wq, p) in self.sigma_w.items())
return reads + writes return reads + writes
# mypy doesn't like calling sum with timedeltas.
@no_type_check
def latency(self, def latency(self,
read_fraction: Optional[Distribution] = None, read_fraction: Optional[Distribution] = None,
write_fraction: Optional[Distribution] = None) \ write_fraction: Optional[Distribution] = None) \
@ -661,16 +665,16 @@ class Strategy(Generic[T]):
fr = sum(p * fr for (fr, p) in d.items()) fr = sum(p * fr for (fr, p) in d.items())
reads = fr * sum(( reads = fr * sum((
p * self.qs._read_quorum_latency({self.node(x) for x in rq}) p * self.qs._read_quorum_latency({self.node(x) for x in rq}) # type: ignore[misc]
for (rq, p) in self.sigma_r.items() for (rq, p) in self.sigma_r.items()
), datetime.timedelta(seconds=0)) ), datetime.timedelta(seconds=0)) # type: ignore[arg-type]
writes = (1 - fr) * sum(( writes = (1 - fr) * sum((
p * self.qs._write_quorum_latency({self.node(x) for x in wq}) p * self.qs._write_quorum_latency({self.node(x) for x in wq}) # type: ignore[misc]
for (wq, p) in self.sigma_w.items() for (wq, p) in self.sigma_w.items()
), datetime.timedelta(seconds=0)) ), datetime.timedelta(seconds=0)) # type: ignore[arg-type]
return reads + writes return reads + writes # type: ignore[return-value]
def node_load(self, def node_load(self,
node: Node[T], node: Node[T],

View file

@ -1,15 +1,12 @@
from .distribution import Distribution from .distribution import Distribution
from .expr import choose, Expr, Node from .expr import choose, Expr, Node, T
from .quorum_system import (LATENCY, LOAD, NETWORK, NoStrategyFoundError, from .quorum_system import (LATENCY, LOAD, NETWORK, NoStrategyFoundError,
QuorumSystem, Strategy, Tuple) QuorumSystem, Strategy, Tuple)
from typing import Iterator, List, Optional, TypeVar from typing import Iterator, List, Optional
import datetime import datetime
import itertools import itertools
T = TypeVar('T')
class NoQuorumSystemFoundError(ValueError): class NoQuorumSystemFoundError(ValueError):
pass pass

View file

@ -1,7 +1,7 @@
from . import distribution from . import distribution
from . import geometry from . import geometry
from .distribution import Distribution from .distribution import Distribution
from .expr import Node from .expr import Node, T
from .geometry import Point, Segment from .geometry import Point, Segment
from .quorum_system import Strategy from .quorum_system import Strategy
from typing import Dict, FrozenSet, List, Optional, Set, Tuple, TypeVar from typing import Dict, FrozenSet, List, Optional, Set, Tuple, TypeVar
@ -10,9 +10,6 @@ import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
T = TypeVar('T')
def plot_node_load(filename: str, def plot_node_load(filename: str,
strategy: Strategy[T], strategy: Strategy[T],
nodes: Optional[List[Node[T]]] = None, nodes: Optional[List[Node[T]]] = None,

View file

@ -1,8 +1,8 @@
from quoracle import *
from quoracle.expr import *
from typing import Any, FrozenSet
import unittest import unittest
from quoracle.expr import Expr, Node, choose
from typing import Any, List, Set, FrozenSet
class TestExpr(unittest.TestCase): class TestExpr(unittest.TestCase):
def test_quorums(self): def test_quorums(self):
def assert_equal(e: Expr[str], xs: List[Set[str]]) -> None: def assert_equal(e: Expr[str], xs: List[Set[str]]) -> None: