quoracle/quorums/strategy.py
2021-01-27 21:22:54 -08:00

245 lines
10 KiB
Python

from . import distribution
from .distribution import Distribution
from .expr import Node
from typing import Dict, Generic, List, Optional, Set, TypeVar
import collections
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
T = TypeVar('T')


class Strategy(Generic[T]):
    """
    A Strategy is a weighted selection over read quorums and write quorums.

    reads[i] is paired with read_weights[i] and writes[i] with
    write_weights[i]. The weights are used as selection probabilities by
    get_read_quorum and get_write_quorum, and as the per-quorum contribution
    to every member node's load.
    """

    def __init__(self,
                 nodes: 'Set[Node[T]]',
                 reads: List[Set[T]],
                 read_weights: List[float],
                 writes: List[Set[T]],
                 write_weights: List[float]) -> None:
        """
        __init__ stores the quorums and weights and precomputes per-node
        load.

        nodes: the nodes of the system. Each node must expose `x` (the
            node's key, which is what quorums contain), `read_capacity` and
            `write_capacity`.
        reads, read_weights: read quorums and their weights, pairwise.
        writes, write_weights: write quorums and their weights, pairwise.
        """
        self.nodes = nodes
        self.read_capacity = {node.x: node.read_capacity for node in nodes}
        self.write_capacity = {node.x: node.write_capacity for node in nodes}
        self.reads = reads
        self.read_weights = read_weights
        self.writes = writes
        self.write_weights = write_weights

        # unweighted_read_load[x] is the total weight of the read quorums
        # that contain x, before dividing by x's capacity. Nodes that appear
        # in no quorum read back as 0.0 thanks to the defaultdict.
        self.unweighted_read_load: Dict[T, float] = \
            collections.defaultdict(float)
        for (read_quorum, weight) in zip(self.reads, self.read_weights):
            for x in read_quorum:
                self.unweighted_read_load[x] += weight

        # Same as above, for write quorums.
        self.unweighted_write_load: Dict[T, float] = \
            collections.defaultdict(float)
        for (write_quorum, weight) in zip(self.writes, self.write_weights):
            for x in write_quorum:
                self.unweighted_write_load[x] += weight

    def __str__(self) -> str:
        """
        __str__ shows only the quorums that have a strictly positive weight.
        """
        non_zero_reads = {tuple(r): p
                          for (r, p) in zip(self.reads, self.read_weights)
                          if p > 0}
        non_zero_writes = {tuple(w): p
                           for (w, p) in zip(self.writes, self.write_weights)
                           if p > 0}
        return f'Strategy(reads={non_zero_reads}, writes={non_zero_writes})'

    def __repr__(self) -> str:
        """
        __repr__ shows every constructor argument, including zero-weight
        quorums.
        """
        return (f'Strategy(nodes={self.nodes}, '+
                f'reads={self.reads}, ' +
                f'read_weights={self.read_weights},' +
                f'writes={self.writes}, ' +
                f'write_weights={self.write_weights})')

    def get_read_quorum(self) -> Set[T]:
        """
        get_read_quorum returns a random read quorum, where reads[i] is
        chosen with probability read_weights[i].
        """
        # Sample an index rather than the quorum itself so that numpy never
        # has to coerce the list of sets into an array, and so the caller
        # gets back the very set object stored in self.reads.
        index = np.random.choice(len(self.reads), p=self.read_weights)
        return self.reads[index]

    def get_write_quorum(self) -> Set[T]:
        """
        get_write_quorum returns a random write quorum, where writes[i] is
        chosen with probability write_weights[i].
        """
        # See get_read_quorum for why an index is sampled.
        index = np.random.choice(len(self.writes), p=self.write_weights)
        return self.writes[index]

    def load(self,
             read_fraction: 'Optional[Distribution]' = None,
             write_fraction: 'Optional[Distribution]' = None) \
             -> float:
        """
        load returns the weighted sum of _load(fr) over the read fraction
        distribution.

        read_fraction and write_fraction are handed to
        distribution.canonicalize_rw, whose result is iterated as
        (read fraction, weight) pairs.
        """
        d = distribution.canonicalize_rw(read_fraction, write_fraction)
        return sum(weight * self._load(fr)
                   for (fr, weight) in d.items())

    def node_load(self,
                  x: T,
                  read_fraction: 'Optional[Distribution]' = None,
                  write_fraction: 'Optional[Distribution]' = None) \
                  -> float:
        """
        node_load returns the weighted sum of _node_load(x, fr) over the
        read fraction distribution. See load for how read_fraction and
        write_fraction are interpreted.
        """
        d = distribution.canonicalize_rw(read_fraction, write_fraction)
        return sum(weight * self._node_load(x, fr)
                   for (fr, weight) in d.items())

    def capacity(self,
                 read_fraction: 'Optional[Distribution]' = None,
                 write_fraction: 'Optional[Distribution]' = None) \
                 -> float:
        """
        capacity returns the reciprocal of load(read_fraction,
        write_fraction). The load is used as a divisor, so it must be
        non-zero.
        """
        return 1 / self.load(read_fraction, write_fraction)

    def plot_node_load(self,
                       filename: str,
                       nodes: 'Optional[List[Node[T]]]' = None,
                       read_fraction: 'Optional[Distribution]' = None,
                       write_fraction: 'Optional[Distribution]' = None) \
                       -> None:
        """
        plot_node_load draws plot_node_load_on onto a new figure, labels the
        axes, and saves the figure to filename.
        """
        fig, ax = plt.subplots()
        self.plot_node_load_on(ax, nodes, read_fraction, write_fraction)
        ax.set_xlabel('Node')
        ax.set_ylabel('Load')
        fig.tight_layout()
        fig.savefig(filename)

    def plot_node_load_on(self,
                          ax: 'plt.Axes',
                          nodes: 'Optional[List[Node[T]]]' = None,
                          read_fraction: 'Optional[Distribution]' = None,
                          write_fraction: 'Optional[Distribution]' = None) \
                          -> None:
        """
        plot_node_load_on draws the per-node stacked bars on ax with bar
        heights divided by node capacity and no additional scaling.
        """
        self._plot_node_load_on(ax,
                                scale=1,
                                scale_by_node_capacity=True,
                                nodes=nodes,
                                read_fraction=read_fraction,
                                write_fraction=write_fraction)

    def plot_node_capacity(self,
                           filename: str,
                           nodes: 'Optional[List[Node[T]]]' = None,
                           read_fraction: 'Optional[Distribution]' = None,
                           write_fraction: 'Optional[Distribution]' = None) \
                           -> None:
        """
        plot_node_capacity draws plot_node_capacity_on onto a new figure,
        labels the axes, and saves the figure to filename.
        """
        fig, ax = plt.subplots()
        self.plot_node_capacity_on(ax, nodes, read_fraction, write_fraction)
        ax.set_xlabel('Node')
        ax.set_ylabel('Throughput at Peak Throughput')
        fig.tight_layout()
        fig.savefig(filename)

    def plot_node_capacity_on(self,
                              ax: 'plt.Axes',
                              nodes: 'Optional[List[Node[T]]]' = None,
                              read_fraction: 'Optional[Distribution]' = None,
                              write_fraction: 'Optional[Distribution]' = None) \
                              -> None:
        """
        plot_node_capacity_on draws the per-node stacked bars on ax with bar
        heights multiplied by capacity(read_fraction, write_fraction) and
        not divided by node capacity.
        """
        self._plot_node_load_on(ax,
                                scale=self.capacity(read_fraction,
                                                    write_fraction),
                                scale_by_node_capacity=False,
                                nodes=nodes,
                                read_fraction=read_fraction,
                                write_fraction=write_fraction)

    def plot_node_utilization(self,
                              filename: str,
                              nodes: 'Optional[List[Node[T]]]' = None,
                              read_fraction: 'Optional[Distribution]' = None,
                              write_fraction: 'Optional[Distribution]' = None) \
                              -> None:
        """
        plot_node_utilization draws plot_node_utilization_on onto a new
        figure, labels the axes, and saves the figure to filename.
        """
        fig, ax = plt.subplots()
        self.plot_node_utilization_on(ax, nodes, read_fraction, write_fraction)
        ax.set_xlabel('Node')
        ax.set_ylabel('Utilization at Peak Throughput')
        fig.tight_layout()
        fig.savefig(filename)

    def plot_node_utilization_on(
            self,
            ax: 'plt.Axes',
            nodes: 'Optional[List[Node[T]]]' = None,
            read_fraction: 'Optional[Distribution]' = None,
            write_fraction: 'Optional[Distribution]' = None) \
            -> None:
        """
        plot_node_utilization_on draws the per-node stacked bars on ax with
        bar heights multiplied by capacity(read_fraction, write_fraction)
        and divided by node capacity.
        """
        self._plot_node_load_on(
            ax,
            scale=self.capacity(read_fraction, write_fraction),
            scale_by_node_capacity=True,
            nodes=nodes,
            read_fraction=read_fraction,
            write_fraction=write_fraction)

    def _node_load(self, x: T, fr: float) -> float:
        """
        _node_load returns the load on x given a fixed read fraction fr.
        """
        fw = 1 - fr
        return (fr * self.unweighted_read_load[x] / self.read_capacity[x] +
                fw * self.unweighted_write_load[x] / self.write_capacity[x])

    def _load(self, fr: float) -> float:
        """
        _load returns the load given a fixed read fraction fr, i.e. the
        largest _node_load over all nodes.
        """
        return max(self._node_load(node.x, fr) for node in self.nodes)

    @staticmethod
    def _quorum_bar_heights(quorum: Set[T],
                            x_index: Dict[T, int],
                            num_bars: int,
                            capacity: Optional[Dict[T, float]]) \
                            -> np.ndarray:
        """
        _quorum_bar_heights returns an array of num_bars heights with a
        non-zero entry at x_index[x] for every x in quorum.

        The entry is 1 when capacity is None and 1 / capacity[x] otherwise.
        Members of quorum that have no entry in x_index (nodes that are not
        being plotted) are skipped.
        """
        heights = np.zeros(num_bars)
        for x in quorum:
            index = x_index.get(x)
            if index is None:
                # x is not among the plotted nodes; it has no bar.
                continue
            heights[index] = 1 if capacity is None else 1 / capacity[x]
        return heights

    def _plot_node_load_on(
            self,
            ax: 'plt.Axes',
            scale: float,
            scale_by_node_capacity: bool,
            nodes: 'Optional[List[Node[T]]]' = None,
            read_fraction: 'Optional[Distribution]' = None,
            write_fraction: 'Optional[Distribution]' = None) \
            -> None:
        """
        _plot_node_load_on draws one stacked bar per node on ax.

        Every read quorum and then every write quorum adds one segment to
        the bar of each plotted node it contains. A segment's height is
        scale * (read or write fraction) * (quorum weight), additionally
        divided by the node's read or write capacity when
        scale_by_node_capacity is set. Each segment is labelled with the
        index of its quorum within self.reads or self.writes. Read segments
        are coloured from the 'Reds' colormap and write segments from
        'Blues'.

        nodes selects and orders the plotted nodes; when it is None or
        empty, all of self.nodes are plotted.
        """
        nodes = nodes or list(self.nodes)
        d = distribution.canonicalize_rw(read_fraction, write_fraction)
        x_list = [node.x for node in nodes]
        x_index = {x: i for (i, x) in enumerate(x_list)}
        x_ticks = list(range(len(x_list)))

        # Collapse the read fraction distribution to its weighted sum.
        fr = sum(weight * f for (f, weight) in d.items())
        fw = 1 - fr

        # Reads are stacked first, then writes on top of them.
        layers = [
            (self.reads, self.read_weights, fr, self.read_capacity, 'Reds'),
            (self.writes, self.write_weights, fw, self.write_capacity,
             'Blues'),
        ]

        # bottoms[j] is the current top of bar j.
        bottoms = np.zeros(len(x_list))
        for (quorums, weights, fraction, capacity, cmap_name) in layers:
            cmap = plt.get_cmap(cmap_name)
            for (i, (quorum, weight)) in enumerate(zip(quorums, weights)):
                bar_heights = scale * fraction * weight * \
                    self._quorum_bar_heights(
                        quorum,
                        x_index,
                        len(x_list),
                        capacity if scale_by_node_capacity else None)
                ax.bar(x_ticks,
                       bar_heights,
                       bottom=bottoms,
                       color=cmap(0.75 - i * 0.5 / len(quorums)),
                       edgecolor='white', width=0.8)

                # Label every non-empty segment with its quorum's index.
                for (j, (bar_height, bottom)) in \
                        enumerate(zip(bar_heights, bottoms)):
                    if bar_height != 0:
                        ax.text(x_ticks[j], bottom + bar_height / 2, str(i),
                                ha='center', va='center')
                bottoms += bar_heights

        ax.set_xticks(x_ticks)
        ax.set_xticklabels([str(x) for x in x_list])