from dataclasses import dataclass, field
from typing import Optional, Iterable, Iterator, Self
from contextlib import contextmanager
from pykappa.pattern import Site, Agent, Component, Pattern, Embedding
from pykappa.utils import IndexedSet
@dataclass(frozen=True)
class Edge:
    """An undirected bond between two sites.

    Orientation does not matter: ``Edge(x, y) == Edge(y, x)`` and both hash
    to the same value, so a set holds at most one edge per pair of sites.

    Attributes:
        site1: One endpoint of the bond.
        site2: The other endpoint of the bond.
    """

    site1: Site
    site2: Site

    def __eq__(self, other: object) -> bool:
        """Compare two edges irrespective of endpoint order."""
        # Defer to the other operand for foreign types rather than failing
        # with AttributeError on a missing `site1`/`site2`.
        if not isinstance(other, Edge):
            return NotImplemented
        return (self.site1 == other.site1 and self.site2 == other.site2) or (
            self.site1 == other.site2 and self.site2 == other.site1
        )

    def __hash__(self) -> int:
        """Hash the unordered pair of endpoints, consistent with `__eq__`."""
        return hash(frozenset((self.site1, self.site2)))
[docs]
class Mixture:
"""A collection of agents and their connections.
Optionally tracks connected components, enabled via `enable_component_tracking()`.
Attributes:
agents: Indexed set of all agents in the mixture.
_components: Indexed set of components if enabled, else None.
_embeddings: Cache of embeddings for tracked components.
_max_embedding_width: Maximum diameter of tracked components.
"""
agents: IndexedSet[Agent]
_components: Optional[IndexedSet[Component]]
_embeddings: dict[Component, IndexedSet[Embedding]]
_max_embedding_width: int
[docs]
@classmethod
def from_kappa(cls, patterns: dict[str, int]) -> Self:
"""Create a mixture from Kappa pattern strings and counts.
Args:
patterns: Dictionary mapping pattern strings to copy counts.
"""
real_patterns = []
for pattern, count in patterns.items():
real_patterns.extend([Pattern.from_kappa(pattern)] * count)
return cls(real_patterns)
def __init__(
self,
patterns: Optional[Iterable[Pattern]] = None,
track_components: bool = False,
):
self.agents = IndexedSet()
self.agents.create_index("type", lambda a: [a.type])
self._components = None
self._embeddings = {}
self._max_embedding_width = 0
if track_components:
self.enable_component_tracking()
if patterns is not None:
for pattern in patterns:
self.instantiate(pattern)
def __iter__(self) -> Iterator[Component]:
yield from self.components
@property
def kappa_str(self) -> str:
"""The mixture in Kappa format with %init declarations."""
return "\n".join(
f"%init: {len(components)} {group.kappa_str}"
for group, components in group_by_isomorphism(
list(component for component in self)
).items()
)
@property
def component_tracking(self) -> bool:
"""Whether connected components are being tracked."""
return self._components is not None
@property
def components(self) -> IndexedSet[Component]:
if self.component_tracking: # Use cached components if tracking
return self._components
else: # Find connected components among the existing agents
components = IndexedSet()
unassigned = set(self.agents)
while unassigned:
seed = next(iter(unassigned))
component_agents = set(seed.depth_first_traversal)
component_agents.intersection_update(self.agents)
components.add(Component(component_agents))
unassigned.difference_update(component_agents)
return components
[docs]
def enable_component_tracking(self) -> None:
"""Turn on connected-component tracking for this mixture."""
if self.component_tracking:
return
self._components = IndexedSet(self.components)
self._components.create_index("agent", lambda c: c.agents)
# If embeddings are already tracked, add a component index to them
for embset in self._embeddings.values():
embset.create_index(
"component",
lambda e: [self.components.lookup_one("agent", next(iter(e.values())))],
)
[docs]
def instantiate(self, pattern: Pattern | str, n_copies: int = 1) -> None:
"""Add instances of a pattern to the mixture.
Args:
pattern: Pattern to instantiate, or Kappa string.
n_copies: Number of copies to create.
Raises:
AssertionError: If pattern is underspecified.
"""
if isinstance(pattern, str):
pattern = Pattern.from_kappa(pattern)
assert (
not pattern.underspecified
), "Pattern isn't specific enough to instantiate."
for _ in range(n_copies):
for component in pattern.components:
self.add(component)
[docs]
def add(self, component: Component) -> None:
"""Add a component to the mixture."""
component_ordered = list(component.agents)
new_agents = [agent.detached() for agent in component_ordered]
new_edges = set()
for i, agent in enumerate(component_ordered):
# Duplicate the proper link structure
for site in agent:
if site.coupled:
partner = site.partner
i_partner = component_ordered.index(partner.agent)
new_site = new_agents[i][site.label]
new_partner = new_agents[i_partner][partner.label]
new_edges.add(Edge(new_site, new_partner))
update = MixtureUpdate(agents_to_add=set(new_agents), edges_to_add=new_edges)
self.apply_update(update)
[docs]
def remove(self, component: Component) -> None:
"""Remove a component from the mixture."""
update = MixtureUpdate()
for agent in component:
update.remove_agent(agent)
self.apply_update(update)
[docs]
def embeddings(self, component: Component) -> IndexedSet[Embedding]:
"""Get embeddings of a tracked component (not accounting for symmetries).
Raises:
KeyError: If component is not being tracked.
"""
try:
return self._embeddings[component]
except KeyError as e:
e.add_note(
f"Undeclared component: {component}. To embed it, first use `track_component`."
)
raise
[docs]
def embeddings_in_component(
self, match_pattern: Component, mixture_component: Component
) -> IndexedSet[Embedding]:
"""Get embeddings of a pattern within a specific component."""
if not self.component_tracking:
raise RuntimeError("Component tracking is not enabled.")
return self._embeddings[match_pattern].lookup("component", mixture_component)
[docs]
def track_component(self, component: Component):
"""Start tracking embeddings of a component."""
self._max_embedding_width = max(component.diameter, self._max_embedding_width)
embeddings = IndexedSet(component.embeddings(self))
embeddings.create_index("agent", lambda e: iter(e.values()))
self._embeddings[component] = embeddings
if self.component_tracking:
embeddings.create_index(
"component",
lambda e: [self.components.lookup_one("agent", next(iter(e.values())))],
)
[docs]
def apply_update(self, update: "MixtureUpdate") -> None:
"""Apply a collection of changes to the mixture."""
for agent in update.touched_before:
for tracked in self._embeddings:
self._embeddings[tracked].remove_by("agent", agent)
for edge in update.edges_to_remove:
self._remove_edge(edge)
for agent in update.agents_to_remove:
self._remove_agent(agent)
for agent in update.agents_to_add:
self._add_agent(agent)
for edge in update.edges_to_add:
self._add_edge(edge)
update_region = neighborhood(update.touched_after, self._max_embedding_width)
update_region = IndexedSet(update_region)
update_region.create_index("type", lambda a: [a.type])
for component_pattern in self._embeddings:
new_embeddings = component_pattern.embeddings(update_region)
for e in new_embeddings:
self._embeddings[component_pattern].add(e)
def _add_agent(self, agent: Agent) -> None:
"""Add an agent to the mixture (should not have any bound sites)."""
assert all(site.partner == "." for site in agent)
assert agent.instantiable
self.agents.add(agent)
if self.component_tracking:
self.components.add(Component([agent]))
def _remove_agent(self, agent: Agent) -> None:
"""Remove an agent from the mixture (bonds must be removed first)."""
assert all(site.partner == "." for site in agent)
self.agents.remove(agent)
if self.component_tracking:
component = self.components.lookup_one("agent", agent)
self.components.remove(component)
def _add_edge(self, edge: Edge) -> None:
"""Add a bond between two sites."""
assert edge.site1.agent in self.agents
assert edge.site2.agent in self.agents
edge.site1.partner = edge.site2
edge.site2.partner = edge.site1
if not self.component_tracking:
return
component1 = self.components.lookup_one("agent", edge.site1.agent)
component2 = self.components.lookup_one("agent", edge.site2.agent)
if component1 == component2:
return
if len(component2) > len(component1):
component1, component2 = component2, component1
with self._relocate_embeddings(component2):
self.components.remove(component2)
for agent in component2:
component1.add(agent)
self.components.indices["agent"][agent] = [component1]
def _remove_edge(self, edge: Edge) -> None:
"""Remove a bond between two sites."""
assert edge.site1.partner == edge.site2
assert edge.site2.partner == edge.site1
edge.site1.partner = "."
edge.site2.partner = "."
if not self.component_tracking:
return
agent1: Agent = edge.site1.agent
agent2: Agent = edge.site2.agent
old_component = self.components.lookup_one("agent", agent1)
assert old_component == self.components.lookup_one("agent", agent2)
maybe_new_component = Component(agent1.depth_first_traversal)
if agent2 in maybe_new_component:
return
new_component1 = maybe_new_component
new_component2 = Component(agent2.depth_first_traversal)
with self._relocate_embeddings(old_component):
self.components.remove(old_component)
self.components.add(new_component1)
self.components.add(new_component2)
@contextmanager
def _relocate_embeddings(self, component: Component):
"""Temporarily evacuate and restore embeddings during component restructuring."""
relocated = {}
for tracked in self._embeddings:
relocated[tracked] = list(
self._embeddings[tracked].lookup("component", component)
)
for e in relocated[tracked]:
self._embeddings[tracked].remove(e)
try:
yield
finally:
for tracked in self._embeddings:
for e in relocated.get(tracked, []):
self._embeddings[tracked].add(e)
@dataclass
class MixtureUpdate:
    """A batch of pending changes to be applied to a mixture.

    The methods only record what should happen; nothing here modifies
    agents or sites.
    """

    agents_to_add: set[Agent] = field(default_factory=set)
    agents_to_remove: set[Agent] = field(default_factory=set)
    edges_to_add: set[Edge] = field(default_factory=set)
    edges_to_remove: set[Edge] = field(default_factory=set)
    agents_changed: set[Agent] = field(default_factory=set)  # Internal state changes

    def create_agent(self, agent: Agent) -> Agent:
        """Schedule a new agent built from a template and return it.

        The new agent is obtained from `agent.detached()`.
        """
        fresh = agent.detached()
        self.agents_to_add.add(fresh)
        return fresh

    def remove_agent(self, agent: Agent) -> None:
        """Schedule an agent, and every bond on its sites, for removal."""
        self.agents_to_remove.add(agent)
        bound_sites = (site for site in agent if site.coupled)
        for site in bound_sites:
            self.edges_to_remove.add(Edge(site, site.partner))

    def connect_sites(self, site1: Site, site2: Site) -> None:
        """Schedule a bond between two sites.

        A site currently bound to some other site has that existing bond
        scheduled for removal. Nothing is added when the two sites are
        already bound to each other.
        """
        for site, intended in ((site1, site2), (site2, site1)):
            if site.coupled and site.partner != intended:
                self.disconnect_site(site)

        already_bound = site1.partner == site2
        if already_bound:
            return
        self.edges_to_add.add(Edge(site1, site2))

    def disconnect_site(self, site: Site) -> None:
        """Schedule removal of the bond on a site, if it has one."""
        if not site.coupled:
            return
        self.edges_to_remove.add(Edge(site, site.partner))

    def register_changed_agent(self, agent: Agent) -> None:
        """Record that an agent's internal state changes in this update."""
        self.agents_changed.add(agent)

    @property
    def touched_before(self) -> set[Agent]:
        """The agents that will be changed or removed by this update."""
        affected = set(self.agents_changed)
        affected.update(self.agents_to_remove)
        for bond in self.edges_to_remove:
            affected.update((bond.site1.agent, bond.site2.agent))
        # Agents created by this update do not exist beforehand.
        for bond in self.edges_to_add:
            affected.update(
                endpoint
                for endpoint in (bond.site1.agent, bond.site2.agent)
                if endpoint not in self.agents_to_add
            )
        return affected

    @property
    def touched_after(self) -> set[Agent]:
        """The agents that will be changed or added after this update."""
        affected = set(self.agents_changed)
        affected.update(self.agents_to_add)
        for bond in self.edges_to_add:
            affected.update((bond.site1.agent, bond.site2.agent))
        # Agents removed by this update no longer exist afterwards.
        for bond in self.edges_to_remove:
            affected.update(
                endpoint
                for endpoint in (bond.site1.agent, bond.site2.agent)
                if endpoint not in self.agents_to_remove
            )
        return affected
def neighborhood(agents: Iterable[Agent], radius: int) -> set[Agent]:
    """Collect every agent within `radius` hops of the given agents.

    Args:
        agents: Starting agents; always included in the result.
        radius: Maximum number of `neighbors` steps to follow.

    Returns:
        The starting agents plus everything reached within `radius` steps.
    """
    reached = set(agents)
    layer = set(reached)
    for _ in range(radius):
        if not layer:
            # Nothing new was found last round, so further rounds are no-ops.
            break
        discovered = {
            adjacent
            for agent in layer
            for adjacent in agent.neighbors
            if adjacent not in reached
        }
        reached |= discovered
        layer = discovered
    return reached
def group_by_isomorphism(
    components: Iterable[Component],
) -> dict[Component, list[Component]]:
    """Partition components into isomorphism classes.

    The first component seen of each class becomes its representative (the
    dictionary key); later components join the first class whose
    representative they are isomorphic to.

    Returns:
        Dictionary mapping representative components to lists of isomorphic components.
    """
    classes: dict[Component, list[Component]] = {}
    for candidate in components:
        representative = next(
            (rep for rep in classes if candidate.isomorphic(rep)), None
        )
        if representative is None:
            classes[candidate] = [candidate]
        else:
            classes[representative].append(candidate)
    return classes