import random
from typing import Any, Optional, Iterable, Generic, TypeVar, Self
from collections import defaultdict
from collections.abc import Callable, Hashable
import numpy as np
[docs]
def equilibrated(
values: list[float],
times: Optional[list[float]],
tail_fraction: float = 0.1,
tolerance: float = 0.01,
) -> bool:
"""
Checks whether the magnitude of the slope of the tail of the series relative to the mean
is sufficiently small (below tolerance). Time can be provided to account for non-uniform
sampling intervals.
Raises:
AssertionError: If there are not enough measurements to assess equilibration.
"""
window_len = int(tail_fraction * len(values))
assert (
len(values) >= window_len and window_len >= 2
), f"Not enough measurements ({window_len}) to assess equilibration"
times = times[-window_len:] if times is not None else list(range(window_len))
values = values[-window_len:]
slope, _ = np.polyfit(times, values, deg=1)
return abs(slope / np.mean(values)) <= tolerance
[docs]
def str_table(rows: list[list], header: Optional[list] = None) -> str:
"""Format rows into a table with aligned columns."""
all_rows = [header] + rows if header else rows
num_cols = len(all_rows[0])
col_widths = [max(len(str(row[i])) for row in all_rows) for i in range(num_cols)]
formatted_rows = []
for i, row in enumerate(all_rows):
formatted_rows.append(
" | ".join(f"{str(item):<{col_widths[j]}}" for j, item in enumerate(row))
)
if i == 0 and header:
formatted_rows.append("-" * len(formatted_rows[-1]))
return "\n".join(formatted_rows)
[docs]
def rejection_sample(population: Iterable, excluded: Iterable, max_attempts: int = 100):
"""Randomly sample an element from `population` that is not in `excluded`."""
population = list(population)
if not population:
raise ValueError("Sequence is empty")
excluded_ids = set(id(x) for x in excluded)
# Fast rejection sampling (O(1) average case for small exclusion sets)
for _ in range(max_attempts):
choice = random.choice(population)
if id(choice) not in excluded_ids:
return choice
# Fallback to O(n) scan only if necessary (rare for small exclusion sets)
valid_choices = [item for item in population if id(item) not in excluded_ids]
if not valid_choices:
raise ValueError("No valid elements to choose from")
return random.choice(valid_choices)
[docs]
class OrderedSet[T]:
def __init__(self, items: Optional[Iterable[T]] = None):
self.dict = dict() if items is None else dict.fromkeys(items)
def __iter__(self):
yield from self.dict
def __len__(self):
return len(self.dict)
[docs]
def add(self, item: Any) -> None:
self.dict[item] = None
[docs]
def remove(self, item: Any) -> None:
del self.dict[item]
[docs]
class Counted:
"""Assigns a unique integer ID to each instance, starting from 0."""
counter = 0
def __init__(self):
self.id = Counted.counter
Counted.counter += 1
def __hash__(self):
return self.id
def __eq__(self, other):
return hash(self) == hash(other)
T = TypeVar("T") # Member type of `IndexedSet`
Property = Callable[[T], Iterable[Hashable]] # Returns the property values of an item
[docs]
class IndexedSet(set[T], Generic[T]):
"""
A subclass of the built-in `set`, with support for indexing by arbitrary
properties of set members and integer indexing to enable random sampling.
Credit https://stackoverflow.com/a/15993515 for the integer indexing logic.
NOTE: Member ordering is not stable across insertions and deletions.
Example usage:
```
[...] # define a SportsTeam class
teams: IndexedSet[SportsTeam] = IndexedSet()
teams.create_index("name", lambda team: [team.name])
teams.create_index("color", lambda team: [team.jersey_color])
[...] # populate the set with teams
teams.lookup_one("name", "Manchester") # Returns the team whose name is "Manchester"
teams.lookup("color", "blue") # Returns all teams with blue jerseys
```
"""
_item_to_pos: dict[T, int]
_item_list: list[T]
properties: dict[str, Property]
indices: dict[str, defaultdict[Hashable, Self]]
def __init__(self, iterable: Iterable[T] = []):
iterable = list(iterable)
super().__init__(iterable)
self._item_list = iterable
self._item_to_pos = {item: i for (i, item) in enumerate(iterable)}
self.properties = {}
self.indices = {}
def __getitem__(self, i):
assert 0 <= i < len(self)
return self._item_list[i]
[docs]
def create_index(self, name: str, prop: Property):
"""Create an index that's updated when adding and removing members.
Note:
Mutating set members outside of interface calls can invalidate indices.
"""
assert name not in self.properties
self.properties[name] = prop
self.indices[name] = defaultdict(IndexedSet)
for el in self:
for val in prop(el):
self.indices[name][val].add(el)
[docs]
def add(self, item: T):
if item in self:
return
super().add(item)
self._item_list.append(item)
self._item_to_pos[item] = len(self._item_list) - 1
self._update_indices_for_item(item, adding=True)
[docs]
def remove(self, item: T):
assert item in self
super().remove(item)
pos = self._item_to_pos.pop(item)
last_item = self._item_list.pop()
if pos < len(self._item_list):
self._item_list[pos] = last_item
self._item_to_pos[last_item] = pos
self._update_indices_for_item(item, adding=False)
def _update_indices_for_item(self, item: T, adding: bool):
"""Update property indices when adding or removing an item."""
for prop_name, prop in self.properties.items():
for val in prop(item):
index = self.indices[prop_name][val]
if adding:
index.add(item)
else:
index.remove(item)
if not index:
del self.indices[prop_name][val]
[docs]
def remove_by(self, prop_name: str, value: Any):
"""Remove all set members whose given property matches `value`."""
if value in self.indices[prop_name]:
for match in list(self.indices[prop_name][value]):
assert match in self
self.remove(match)
[docs]
def lookup(self, name: str, value: Any) -> Self:
"""Return an IndexedSet of all matching items."""
return self.indices[name][value]
[docs]
def lookup_one(self, name: str, value: Any) -> T:
"""Return a single matching item. Raises if not exactly one match."""
matches = self.indices[name][value]
assert len(matches) == 1
return next(iter(matches))