Added some more skeleton around strategies.

This commit is contained in:
Michael Whittaker 2021-01-20 17:11:59 -08:00
parent 56992f1c8d
commit efbeb8dc44

View file

@ -1,5 +1,8 @@
from typing import Iterator, Generic, List, Optional, Set, TypeVar
from typing import (Dict, Iterator, Generic, List, Optional, Set, Tuple,
TypeVar, Union)
import itertools
import numpy as np
import pulp
T = TypeVar('T')
@ -139,15 +142,53 @@ def _or(lhs: Expr[T], rhs: Expr[T]) -> 'Or[T]':
return Or([lhs, rhs])
def choose(k: int, es: List[Expr[T]]) -> Choose[T]:
return Choose(k, es)
def choose(k: int, es: List[Expr[T]]) -> Expr[T]:
if k == 1:
return Or(es)
elif k == len(es):
return And(es)
else:
return Choose(k, es)
def majority(es: List[Expr[T]]) -> Choose[T]:
return Choose(len(es) // 2 + 1, es)
def majority(es: List[Expr[T]]) -> Expr[T]:
return choose(len(es) // 2 + 1, es)
class QuorumSystem:
Distribution = Union[int, float, Dict[float, float], List[Tuple[float, float]]]
def _canonicalize_distribution(d: Distribution) -> Dict[float, float]:
if isinstance(d, int):
if d < 0 or d > 1:
raise ValueError('distribution must be in the range [0, 1]')
return {float(d): 1.}
elif isinstance(d, float):
if d < 0 or d > 1:
raise ValueError('distribution must be in the range [0, 1]')
return {d: 1.}
elif isinstance(d, dict):
if len(d) == 0:
raise ValueError('distribution cannot empty')
if any(weight < 0 for weight in d.values()):
raise ValueError('distribution cannot have negative weights')
total_weight = sum(d.values())
if total_weight == 0:
raise ValueError('distribution cannot have zero weight')
return {float(f): weight / total_weight
for (f, weight) in d.items()
if weight > 0}
elif isinstance(d, list):
return _canonicalize_distribution({f: weight for (f, weight) in d})
else:
raise ValueError('distribution must be an int, a float, a Dict[float, '
'float] or a List[Tuple[float, float]]')
class QuorumSystem(Generic[T]):
def __init__(self, reads: Optional[Expr[T]] = None,
writes: Optional[Expr[T]] = None) -> None:
if reads is not None and writes is not None:
@ -170,17 +211,55 @@ class QuorumSystem:
def __repr__(self) -> str:
return f'QuorumSystem(reads={self.reads}, writes={self.writes})'
def strategy(self, read_fraction: Distribution) -> 'Strategy[T]':
# TODO(mwhittaker): Implement.
reads = list(self.read_quorums())
writes = list(self.write_quorums())
return ExplicitStrategy(reads, [1 / len(reads)] * len(reads),
writes, [1 / len(writes)] * len(writes))
def is_read_quorum(self, xs: Set[T]) -> bool:
return self.reads.is_quorum(xs)
def read_quorums(self) -> Iterator[Set[T]]:
return self.reads.quorums()
def write_quorums(self) -> Iterator[Set[T]]:
return self.writes.quorums()
def is_read_quorum(self, xs: Set[T]) -> bool:
return self.reads.is_quorum(xs)
def is_write_quorum(self, xs: Set[T]) -> bool:
return self.writes.is_quorum(xs)
class Strategy(Generic[T]):
def load(self, read_fraction: Distribution) -> int:
raise NotImplementedError
def get_read_quorum(self) -> Set[T]:
raise NotImplementedError
def get_write_quorum(self) -> Set[T]:
raise NotImplementedError
class ExplicitStrategy(Strategy[T]):
def __init__(self,
reads: List[Set[T]],
read_weights: List[float],
writes: List[Set[T]],
write_weights: List[float]) -> None:
self.reads = reads
self.read_weights = read_weights
self.writes = writes
self.write_weights = write_weights
# TODO(mwhittaker): Implement __str__ and __repr__.
def load(self, read_fraction: Distribution) -> int:
raise NotImplementedError
def get_read_quorum(self) -> Set[T]:
return np.random.choice(self.reads, p=self.read_weights)
def get_write_quorum(self) -> Set[T]:
return np.random.choice(self.writes, p=self.write_weights)
a = Node('a')
@ -188,18 +267,14 @@ b = Node('b')
c = Node('c')
d = Node('d')
e = Node('e')
f = Node('g')
f = Node('f')
g = Node('g')
h = Node('h')
i = Node('i')
disjunction = a + b + c
conjunction = disjunction * disjunction * disjunction
print(conjunction)
print(conjunction.dual())
print(conjunction.dual().dual())
print(QuorumSystem(reads=conjunction))
grid = QuorumSystem(reads=a*b*c + d*e*f + g*h*i)
sigma = grid.strategy(0.1)
for _ in range(10):
print(sigma.get_write_quorum())
# - num_quorums
# - has dups?