"""
Concrete implementation of the agents collection for mesa-frames.
This module provides the concrete implementation of the agents collection class
for the mesa-frames library. It defines the AgentSetRegistry class, which serves as a
container for all agent sets in a model, leveraging DataFrame-based storage for
improved performance.
Classes:
AgentSetRegistry(AbstractAgentSetRegistry):
A collection of AgentSets. This class acts as a container for all
agents in the model, organizing them into separate AgentSet instances
based on their types.
The AgentSetRegistry class is designed to be used within Model instances to manage
all agents in the simulation. It provides methods for adding, removing, and
accessing agents and agent sets, while taking advantage of the performance
benefits of DataFrame-based agent storage.
Usage:
The AgentSetRegistry class is typically instantiated and used within a Model subclass:
from mesa_frames.concrete.model import Model
from mesa_frames.concrete.agents import AgentSetRegistry
from mesa_frames.concrete import AgentSet
class MyCustomModel(Model):
def __init__(self):
super().__init__()
# Adding agent sets to the collection
self.sets += AgentSet(self)
self.sets += AnotherAgentSet(self)
def step(self):
# Step all agent sets
self.sets.do("step")
Note:
This concrete implementation builds upon the AbstractAgentSetRegistry class
defined in the mesa_frames.abstract package, providing a ready-to-use
agents collection that integrates with the DataFrame-based agent storage system.
For more detailed information on the AgentSetRegistry class and its methods, refer to
the class docstring.
"""
from __future__ import annotations # For forward references
from collections.abc import Collection, Iterable, Iterator, Sequence
from typing import Any, Literal, Self, overload, cast
from collections.abc import Sized
from itertools import chain
import polars as pl
from mesa_frames.abstract.agentsetregistry import (
AbstractAgentSetRegistry,
)
from mesa_frames.concrete.agentset import AgentSet
from mesa_frames.types_ import BoolSeries, KeyBy, AgentSetSelector
[docs]
class AgentSetRegistry(AbstractAgentSetRegistry):
"""A collection of AgentSets. All agents of the model are stored here."""
# AgentSets held by this registry, in registry (positional) order.
_agentsets: list[AgentSet]
# Cached concatenation of the "unique_id" values of the contained AgentSets.
_ids: pl.Series
[docs]
def __init__(self, model: mesa_frames.concrete.model.Model) -> None:
    """Create an empty registry bound to ``model``.

    Parameters
    ----------
    model : mesa_frames.concrete.model.Model
        The model that owns this AgentSetRegistry.
    """
    # The id cache starts out empty but already typed and named.
    empty_ids = pl.Series(name="unique_id", dtype=pl.UInt64)
    self._model = model
    self._agentsets = []
    self._ids = empty_ids
[docs]
def add(
    self,
    sets: AgentSet | Iterable[AgentSet],
    inplace: bool = True,
) -> Self:
    """Add one or more AgentSets to the registry.

    Names are kept unique across the registry: a set whose name is missing
    falls back to its class name, and a clashing name gets a numeric suffix
    (``name_1``, ``name_2``, ...).

    Parameters
    ----------
    sets : AgentSet | Iterable[AgentSet]
        The AgentSet(s) to add.
    inplace : bool, optional
        Passed to ``_get_obj`` to choose the registry object that is updated
        and returned. By default True.

    Returns
    -------
    Self
        The registry object that received the new sets.

    Raises
    ------
    ValueError
        If any of the sets is already in the registry, or if the combined
        agent IDs are not unique.
    """
    obj = self._get_obj(inplace)
    other_list = obj._return_agentsets_list(sets)
    if obj._check_agentsets_presence(other_list).any():
        raise ValueError(
            "Some agentsets are already present in the AgentSetRegistry."
        )
    # Validate ID uniqueness before touching any names, so that a rejected
    # call leaves the incoming AgentSets unmodified.
    new_ids = pl.concat(
        [obj._ids] + [pl.Series(agentset["unique_id"]) for agentset in other_list]
    )
    if new_ids.is_duplicated().any():
        raise ValueError("Some of the agent IDs are not unique.")
    # Ensure unique names across existing and to-be-added sets
    existing_names = {s.name for s in obj._agentsets}
    for agentset in other_list:
        base_name = agentset.name or agentset.__class__.__name__
        name = base_name
        if name in existing_names:
            counter = 1
            candidate = f"{base_name}_{counter}"
            while candidate in existing_names:
                counter += 1
                candidate = f"{base_name}_{counter}"
            name = candidate
        # Assign back whenever the final name differs from the current one;
        # this also covers a set that had no name at all.
        if agentset.name != name:
            agentset.name = name
        existing_names.add(name)
    obj._agentsets.extend(other_list)
    obj._ids = new_ids
    return obj
[docs]
def rename(
    self,
    target: (
        AgentSet
        | str
        | dict[AgentSet | str, str]
        | list[tuple[AgentSet | str, str]]
    ),
    new_name: str | None = None,
    *,
    on_conflict: Literal["canonicalize", "raise"] = "canonicalize",
    mode: Literal["atomic", "best_effort"] = "atomic",
    inplace: bool = True,
) -> Self:
    """Rename AgentSets with conflict handling.

    Supports single-target ``(set | old_name, new_name)`` and batch rename via
    dict or list of pairs. Names remain unique across the registry.

    Parameters
    ----------
    target : AgentSet | str | dict | list[tuple]
        A single AgentSet (or its current name), or a batch given as a
        mapping / list of ``(set_or_name, new_name)`` pairs.
    new_name : str | None, optional
        The new name; required when ``target`` is a single set or name.
    on_conflict : {"canonicalize", "raise"}, optional
        ``"canonicalize"`` appends ``_1``, ``_2``, ... to a clashing name;
        ``"raise"`` treats a clash as an error (see ``mode``).
    mode : {"atomic", "best_effort"}, optional
        Only consulted when ``on_conflict="raise"``. ``"atomic"`` raises on the
        first clash and renames nothing; ``"best_effort"`` skips the clashing
        renames and applies the others.
    inplace : bool, optional
        Passed to ``_get_obj`` to choose the registry object that is updated
        and returned. By default True.

    Returns
    -------
    Self
        The registry object whose sets were renamed.

    Raises
    ------
    ValueError
        If ``on_conflict`` or ``mode`` is not one of the supported values, or
        on a name clash with ``on_conflict="raise"`` and ``mode="atomic"``.
    TypeError
        If a single target is given without ``new_name``.
    KeyError
        If a target set or name is not found in the registry.
    """
    # Reject unknown option values up front instead of silently falling
    # through to the "raise" / "best_effort" branches.
    if on_conflict not in ("canonicalize", "raise"):
        raise ValueError(
            f"on_conflict must be 'canonicalize' or 'raise', got {on_conflict!r}"
        )
    if mode not in ("atomic", "best_effort"):
        raise ValueError(f"mode must be 'atomic' or 'best_effort', got {mode!r}")

    # Normalize to list of (index_in_self, desired_name) using the original registry
    def _resolve_one(x: AgentSet | str) -> int:
        if isinstance(x, AgentSet):
            for i, s in enumerate(self._agentsets):
                if s is x:
                    return i
            raise KeyError("AgentSet not found in registry")
        # name lookup on original registry
        for i, s in enumerate(self._agentsets):
            if s.name == x:
                return i
        raise KeyError(f"Agent set '{x}' not found")

    if isinstance(target, (AgentSet, str)):
        if new_name is None:
            raise TypeError("new_name must be provided for single rename")
        pairs_idx: list[tuple[int, str]] = [(_resolve_one(target), new_name)]
    elif isinstance(target, dict):
        pairs_idx = [(_resolve_one(k), v) for k, v in target.items()]
    else:
        pairs_idx = [(_resolve_one(k), v) for k, v in target]

    # Choose object to mutate
    obj = self._get_obj(inplace)
    # Translate indices to the AgentSets of the selected registry object
    target_sets = [obj._agentsets[i] for i, _ in pairs_idx]

    # Names that stay fixed: every set that is not itself being renamed.
    # Targets are matched by identity, consistent with _resolve_one.
    target_ids = {id(s) for s in target_sets}
    used: set[str] = {
        s.name
        for s in obj._agentsets
        if s.name is not None and id(s) not in target_ids
    }

    def _canonicalize(base: str) -> str:
        if base not in used:
            used.add(base)
            return base
        counter = 1
        cand = f"{base}_{counter}"
        while cand in used:
            counter += 1
            cand = f"{base}_{counter}"
        used.add(cand)
        return cand

    # Plan every final name before applying any, so an atomic failure
    # leaves all sets untouched.
    final: list[tuple[AgentSet, str]] = []
    for aset, (_idx, desired) in zip(target_sets, pairs_idx):
        if on_conflict == "canonicalize":
            final.append((aset, _canonicalize(desired)))
        elif desired in used:
            if mode == "atomic":
                raise ValueError(
                    f"Duplicate agent set name disallowed: '{desired}'"
                )
            # best_effort: skip this rename
            continue
        else:
            used.add(desired)
            final.append((aset, desired))

    # Apply renames
    for aset, newn in final:
        # Set the private name directly to avoid external uniqueness hooks
        if hasattr(aset, "_name"):
            aset._name = newn  # type: ignore[attr-defined]
    return obj
[docs]
def replace(
    self,
    mapping: (dict[int | str, AgentSet] | list[tuple[int | str, AgentSet]]),
    *,
    inplace: bool = True,
    atomic: bool = True,
) -> Self:
    """Replace AgentSets in the registry by index or by name.

    Parameters
    ----------
    mapping : dict[int | str, AgentSet] | list[tuple[int | str, AgentSet]]
        Slots to overwrite. An ``int`` key is a position; a ``str`` key is the
        name of the set currently occupying the slot. When replacing by name
        the new set's private ``_name`` is set to that name.
    inplace : bool, optional
        Whether to update this registry or a copy. By default True.
    atomic : bool, optional
        If True, validate every item before applying any of them. If False,
        items are applied one by one without validation. By default True.

    Returns
    -------
    Self
        The registry object that was updated.

    Raises
    ------
    TypeError
        (atomic) If a value is not an AgentSet, belongs to another model, or a
        key is neither ``int`` nor ``str``.
    IndexError
        (atomic) If an index is outside ``[0, len(registry))``.
    ValueError
        (atomic) If an AgentSet instance would end up in two slots, or an
        index assignment would duplicate an existing name.
    KeyError
        If a name key does not match any AgentSet.
    """
    # Normalize to list of (key, value)
    items: list[tuple[int | str, AgentSet]] = (
        list(mapping.items()) if isinstance(mapping, dict) else list(mapping)
    )
    obj = self._get_obj(inplace)

    # Snapshot of name -> index taken before any slot is overwritten.
    name_to_idx = {
        s.name: i for i, s in enumerate(obj._agentsets) if s.name is not None
    }

    def _find_index_by_name(name: str) -> int:
        try:
            return name_to_idx[name]
        except KeyError:
            raise KeyError(f"Agent set '{name}' not found") from None

    if atomic:
        n = len(obj._agentsets)
        # Map existing object identity -> index (for aliasing checks)
        id_to_idx = {id(s): i for i, s in enumerate(obj._agentsets)}
        # Slot each incoming instance is headed for in this call, so the
        # same new object cannot be assigned to two different slots.
        planned_slot: dict[int, int] = {}
        for k, v in items:
            if not isinstance(v, AgentSet):
                raise TypeError("Values must be AgentSet instances")
            if v.model is not obj.model:
                raise TypeError(
                    "All AgentSets must belong to the same model as the registry"
                )
            v_idx_existing = id_to_idx.get(id(v))
            if isinstance(k, int):
                if not (0 <= k < n):
                    raise IndexError(
                        f"Index {k} out of range for AgentSetRegistry of size {n}"
                    )
                slot = k
                # Prevent aliasing: the same object cannot appear in two positions
                if v_idx_existing is not None and v_idx_existing != k:
                    raise ValueError(
                        f"This AgentSet instance already exists at index {v_idx_existing}; cannot also place it at {k}."
                    )
                # Preserve name uniqueness when assigning by index
                vname = v.name
                if vname is not None:
                    other_idx = name_to_idx.get(vname)
                    if other_idx is not None and other_idx != k:
                        raise ValueError(
                            f"Duplicate agent set name disallowed: '{vname}' already at index {other_idx}"
                        )
            elif isinstance(k, str):
                # Locate the slot by name; replacing that slot preserves uniqueness
                slot = _find_index_by_name(k)
                # Prevent aliasing: if the same object already exists at a different slot, forbid
                if v_idx_existing is not None and v_idx_existing != slot:
                    raise ValueError(
                        f"This AgentSet instance already exists at index {v_idx_existing}; cannot also place it at {slot}."
                    )
            else:
                raise TypeError("Keys must be int indices or str names")
            previous_slot = planned_slot.setdefault(id(v), slot)
            if previous_slot != slot:
                raise ValueError(
                    f"This AgentSet instance is assigned to both index {previous_slot} and index {slot} in the same call."
                )

    # Apply
    target = obj if inplace else obj.copy(deep=False)
    if not inplace:
        target._agentsets = list(obj._agentsets)
    for k, v in items:
        if isinstance(k, int):
            target._agentsets[k] = v  # keep v.name as-is (validated above)
        else:
            idx = _find_index_by_name(k)
            # Force the authoritative name without triggering external uniqueness checks
            if hasattr(v, "_name"):
                v._name = k  # type: ignore[attr-defined]
            target._agentsets[idx] = v
    # Recompute ids cache
    target._recompute_ids()
    return target
# Typing-only overload: a single AgentSet, AgentSet type, or name yields a bool.
@overload
def contains(self, sets: AgentSet | type[AgentSet] | str) -> bool: ...
# Typing-only overload: an iterable of sets, types, or names yields a boolean Series.
@overload
def contains(
self,
sets: Iterable[AgentSet] | Iterable[type[AgentSet]] | Iterable[str],
) -> pl.Series: ...
[docs]
def contains(
    self,
    sets: AgentSet
    | type[AgentSet]
    | str
    | Iterable[AgentSet]
    | Iterable[type[AgentSet]]
    | Iterable[str],
) -> bool | pl.Series:
    """Check whether AgentSets are present, by instance, type, or name.

    Parameters
    ----------
    sets : AgentSet | type[AgentSet] | str | Iterable of those
        What to look for. The kind of an iterable is decided by its first
        element.

    Returns
    -------
    bool | pl.Series
        A ``bool`` for a single instance/type/name. For an iterable, a boolean
        Series with one entry per input element, in input order. An empty
        iterable returns ``True``.

    Raises
    ------
    TypeError
        If the first element of an iterable is not an AgentSet, an AgentSet
        subclass, or a string.
    """
    # Single value fast paths
    if isinstance(sets, AgentSet):
        return self._check_agentsets_presence([sets]).any()
    if isinstance(sets, type) and issubclass(sets, AgentSet):
        return any(isinstance(s, sets) for s in self._agentsets)
    if isinstance(sets, str):
        return any(s.name == sets for s in self._agentsets)
    # Iterable paths without materializing unnecessarily
    if isinstance(sets, Sized) and len(sets) == 0:  # type: ignore[arg-type]
        return True
    it = iter(sets)  # type: ignore[arg-type]
    try:
        first = next(it)
    except StopIteration:
        return True
    if isinstance(first, AgentSet):
        # Align the answer with the input (one entry per requested set),
        # matching what the type and name branches below return.
        registered = set(self._agentsets)
        return pl.Series(
            [s in registered for s in chain([first], it)], dtype=pl.Boolean
        )
    if isinstance(first, type) and issubclass(first, AgentSet):
        present_types = {type(s) for s in self._agentsets}

        def has_type(t: type[AgentSet]) -> bool:
            return any(issubclass(pt, t) for pt in present_types)

        return pl.Series(
            (has_type(t) for t in chain([first], it)), dtype=pl.Boolean
        )
    if isinstance(first, str):
        names = {s.name for s in self._agentsets if s.name is not None}
        return pl.Series((x in names for x in chain([first], it)), dtype=pl.Boolean)
    raise TypeError("Unsupported type for contains()")
# Typing-only overload: without return_results the call returns the registry.
@overload
def do(
self,
method_name: str,
*args: Any,
sets: AgentSetSelector | None = None,
return_results: Literal[False] = False,
inplace: bool = True,
key_by: KeyBy = "name",
**kwargs: Any,
) -> Self: ...
# Typing-only overload: with return_results=True the call returns a dict of
# per-set results keyed by name, index, or type.
@overload
def do(
self,
method_name: str,
*args: Any,
sets: AgentSetSelector,
return_results: Literal[True],
inplace: bool = True,
key_by: KeyBy = "name",
**kwargs: Any,
) -> dict[str, Any] | dict[int, Any] | dict[type[AgentSet], Any]: ...
[docs]
def do(
    self,
    method_name: str,
    *args: Any,
    sets: AgentSetSelector = None,
    return_results: bool = False,
    inplace: bool = True,
    key_by: KeyBy = "name",
    **kwargs: Any,
) -> Self | Any:
    """Invoke ``method_name`` on the selected AgentSets.

    Parameters
    ----------
    method_name : str
        Name of the method forwarded to each AgentSet's ``do``.
    *args : Any
        Positional arguments forwarded to each call.
    sets : AgentSetSelector, optional
        Which sets to act on; ``None`` selects every set.
    return_results : bool, optional
        If True, return a dict of per-set results instead of the registry.
    inplace : bool, optional
        Forwarded to each AgentSet's ``do`` and used to choose the registry
        object that is updated. By default True.
    key_by : KeyBy, optional
        How result keys are built: ``"name"``, ``"index"`` or ``"type"``.
        Any other value keys the results by the AgentSet object itself.
    **kwargs : Any
        Keyword arguments forwarded to each call.

    Returns
    -------
    Self | Any
        The registry object, or the results dict when ``return_results`` is
        True (``{}`` if nothing was selected).

    Raises
    ------
    ValueError
        If a selected set cannot be located in the registry, or if
        ``key_by="type"`` would produce duplicate keys.
    """
    obj = self._get_obj(inplace)
    chosen = obj._resolve_selector(sets)
    if not chosen:
        if return_results:
            return {}
        return obj

    # Identity -> position, so each selected set maps back to its slot.
    position_of = {id(member): pos for pos, member in enumerate(obj._agentsets)}

    if not return_results:
        pending: list[tuple[int, AgentSet]] = []
        for member in chosen:
            try:
                slot = position_of[id(member)]
            except KeyError as exc:  # pragma: no cover - defensive
                raise ValueError(
                    "AgentSet not found in registry; cannot apply operation."
                ) from exc
            outcome = member.do(
                method_name,
                *args,
                return_results=False,
                inplace=inplace,
                **kwargs,
            )
            pending.append((slot, outcome))
        # Write results back only after every call has succeeded.
        for slot, outcome in pending:
            obj._agentsets[slot] = outcome
        obj._recompute_ids()
        return obj

    collected: dict[Any, Any] = {}
    for member in chosen:
        if key_by == "name":
            key = member.name
        elif key_by == "index":
            try:
                key = position_of[id(member)]
            except KeyError as exc:  # pragma: no cover - defensive
                raise ValueError(
                    "AgentSet not found in registry; cannot key by index."
                ) from exc
        elif key_by == "type":
            key = type(member)
        else:
            key = member  # backward-compatible: key by object
        if key_by == "type" and key in collected:
            raise ValueError(
                "Multiple agent sets of the same type were selected; "
                "use key_by='name' or key_by='index' instead."
            )
        collected[key] = member.do(
            method_name,
            *args,
            return_results=True,
            inplace=inplace,
            **kwargs,
        )
    return collected
# Typing-only overloads: an int or str key yields one AgentSet (or None),
# a type key yields a list of AgentSets.
@overload
def get(self, key: int, default: None = ...) -> AgentSet | None: ...
@overload
def get(self, key: str, default: None = ...) -> AgentSet | None: ...
@overload
def get(self, key: type[AgentSet], default: None = ...) -> list[AgentSet]: ...
# Typing-only overload: general form with an explicit default.
@overload
def get(
self,
key: int | str | type[AgentSet],
default: AgentSet | list[AgentSet] | None,
) -> AgentSet | list[AgentSet] | None: ...
[docs]
def get(
    self,
    key: int | str | type[AgentSet],
    default: AgentSet | list[AgentSet] | None = None,
) -> AgentSet | list[AgentSet] | None:
    """Look up AgentSet(s) by index, name, or type without raising.

    Parameters
    ----------
    key : int | str | type[AgentSet]
        Position, name, or AgentSet class to look for.
    default : AgentSet | list[AgentSet] | None, optional
        Value returned when an index is out of range, a name is unknown, or
        the key is of an unsupported kind.

    Returns
    -------
    AgentSet | list[AgentSet] | None
        The set at the index, the first set with the name, or the list of
        sets that are instances of the type (possibly empty); otherwise
        ``default``.
    """
    try:
        if isinstance(key, int):
            return self._agentsets[key]
        if isinstance(key, str):
            # First set carrying this name, else the caller's default.
            return next(
                (member for member in self._agentsets if member.name == key),
                default,
            )
        if isinstance(key, type) and issubclass(key, AgentSet):
            return [member for member in self._agentsets if isinstance(member, key)]
    except (IndexError, KeyError, TypeError):
        return default
    return default
[docs]
def remove(
    self,
    sets: AgentSetSelector,
    inplace: bool = True,
) -> Self:
    """Remove the selected AgentSets from the registry.

    Parameters
    ----------
    sets : AgentSetSelector
        Selector (instance, type, name, or collection) of the sets to drop.
    inplace : bool, optional
        Passed to ``_get_obj`` to choose the registry object that is updated
        and returned. By default True.

    Returns
    -------
    Self
        The registry object without the selected sets.
    """
    obj = self._get_obj(inplace)
    # Turn the selector into concrete AgentSet instances.
    doomed = obj._resolve_selector(sets)  # type: ignore[arg-type]
    # Agents leave the space before their sets are detached from the registry.
    # NOTE(review): this goes through self.model even when inplace is False.
    if self.model._space is not None:
        self.model.space.remove_agents(doomed)
    positions = [
        pos for pos, member in enumerate(obj._agentsets) if member in doomed
    ]
    # Delete from the back so earlier positions stay valid.
    for pos in reversed(positions):
        del obj._agentsets[pos]
    # Refresh the id cache to match the remaining sets.
    obj._recompute_ids()
    return obj
[docs]
def shuffle(self, inplace: bool = False) -> Self:
    """Shuffle every AgentSet held by the registry.

    Parameters
    ----------
    inplace : bool, optional
        Passed to ``_get_obj`` to choose the registry object that is updated
        and returned. By default False.

    Returns
    -------
    Self
        The registry object whose sets were shuffled.
    """
    obj = self._get_obj(inplace)
    reshuffled = []
    for member in obj._agentsets:
        # Each set is always shuffled in place, whatever `inplace` is.
        reshuffled.append(member.shuffle(inplace=True))
    obj._agentsets = reshuffled
    return obj
[docs]
def sort(
    self,
    by: str | Sequence[str],
    ascending: bool | Sequence[bool] = True,
    inplace: bool = True,
    **kwargs: Any,
) -> Self:
    """Sort every AgentSet held by the registry.

    Parameters
    ----------
    by : str | Sequence[str]
        Column name(s) forwarded to each AgentSet's ``sort``.
    ascending : bool | Sequence[bool], optional
        Sort direction(s) forwarded to each AgentSet's ``sort``.
    inplace : bool, optional
        Forwarded to each AgentSet's ``sort`` and used to choose the registry
        object that is updated. By default True.
    **kwargs : Any
        Extra keyword arguments forwarded to each AgentSet's ``sort``.

    Returns
    -------
    Self
        The registry object whose sets were sorted.
    """
    obj = self._get_obj(inplace)
    ordered = []
    for member in obj._agentsets:
        ordered.append(
            member.sort(by=by, ascending=ascending, inplace=inplace, **kwargs)
        )
    obj._agentsets = ordered
    return obj
def _check_ids_presence(self, other: list[AgentSet]) -> pl.DataFrame:
"""Report, for each agent ID in ``other``, whether that ID is already known.
An ID counts as present when it is in the registry's ID cache or was
contributed by an AgentSet that appears earlier in ``other``.
Parameters
----------
other : list[AgentSet]
The AgentSets to check.
Returns
-------
pl.DataFrame
A DataFrame with the IDs of ``other`` (column ``unique_id``) and a boolean
column ``present`` indicating if they are already present.
"""
# Seed the frame with the registry's own IDs, all flagged as present.
presence_df = pl.DataFrame(
data={"unique_id": self._ids, "present": True},
schema={"unique_id": pl.UInt64, "present": pl.Boolean},
)
for agentset in other:
new_ids = pl.Series(agentset.index, dtype=pl.UInt64)
# Membership is tested against everything accumulated so far, then the
# new rows are appended so later sets are also checked against these IDs.
presence_df = pl.concat(
[
presence_df,
(
new_ids.is_in(presence_df["unique_id"])
.to_frame("present")
.with_columns(unique_id=new_ids)
.select(["unique_id", "present"])
),
]
)
# Drop the seed rows so only the IDs coming from ``other`` are returned.
presence_df = presence_df.slice(self._ids.len())
return presence_df
def _check_agentsets_presence(self, other: list[AgentSet]) -> pl.Series:
    """Check which of the given agent sets are already in the AgentSetRegistry.

    Parameters
    ----------
    other : list[AgentSet]
        The AgentSets to check.

    Returns
    -------
    pl.Series
        A boolean Series with one entry per element of ``other``, in the same
        order, that is True where the agent set is already registered.
    """
    # Build the lookup from the registry side so the result lines up with
    # ``other`` rather than with the registry's own ordering.
    registered = set(self._agentsets)
    return pl.Series(
        [agentset in registered for agentset in other], dtype=pl.Boolean
    )
def _recompute_ids(self) -> None:
    """Rebuild the registry-level `unique_id` cache from current AgentSets.

    With no AgentSets the cache is reset to an empty `pl.UInt64` Series named
    ``unique_id``; otherwise it is the concatenation of each set's
    ``unique_id`` column in registry order.
    """
    if not self._agentsets:
        self._ids = pl.Series(name="unique_id", dtype=pl.UInt64)
        return
    # At least one set exists here, so the concat input is never empty.
    self._ids = pl.concat([pl.Series(s["unique_id"]) for s in self._agentsets])
def _resolve_selector(self, selector: AgentSetSelector = None) -> list[AgentSet]:
    """Resolve a selector (instance/type/name or collection) to a list of AgentSets.

    ``None`` selects every set. A single instance, type, or name is matched
    directly; anything else is iterated and each element matched the same
    way, with duplicates dropped while keeping first-seen order.

    Raises
    ------
    TypeError
        If an element of a collection is not an AgentSet, an AgentSet
        subclass, or a string.
    """
    registered = self._agentsets
    if selector is None:
        return list(registered)

    def _match(item: Any) -> list[AgentSet] | None:
        """Return the sets matched by one selector item, or None if unsupported."""
        if isinstance(item, AgentSet):
            return [item] if item in registered else []
        if isinstance(item, type) and issubclass(item, AgentSet):
            return [member for member in registered if isinstance(member, item)]
        if isinstance(item, str):
            return [member for member in registered if member.name == item]
        return None

    direct = _match(selector)
    if direct is not None:
        return direct

    # Collection of mixed selectors
    gathered: list[AgentSet] = []
    for item in selector:  # type: ignore[assignment]
        found = _match(item)
        if found is None:
            raise TypeError("Unsupported selector element type")
        gathered.extend(found)
    # dict keys are unique and keep insertion order, which deduplicates
    # while preserving the order in which sets were first selected.
    return list(dict.fromkeys(gathered))
def _return_agentsets_list(
    self, agentsets: AgentSet | Iterable[AgentSet]
) -> list[AgentSet]:
    """Normalize one AgentSet or an iterable of them into a list.

    Parameters
    ----------
    agentsets : AgentSet | Iterable[AgentSet]
        A single AgentSet or any iterable of AgentSets.

    Returns
    -------
    list[AgentSet]
        A one-element list for a single AgentSet, otherwise the iterable
        materialized as a list.
    """
    if isinstance(agentsets, AgentSet):
        return [agentsets]
    return list(agentsets)
def _generate_name(self, base_name: str) -> str:
    """Return ``base_name``, or ``base_name_<n>`` with the lowest free ``n >= 1``."""
    taken = {
        member.name for member in self._agentsets if member.name is not None
    }
    if base_name not in taken:
        return base_name
    suffix = 1
    while f"{base_name}_{suffix}" in taken:
        suffix += 1
    return f"{base_name}_{suffix}"
[docs]
def __getattr__(self, name: str) -> Any | dict[str, Any]:
    """Fetch ``name`` from every AgentSet, returned as ``{set name: value}``.

    Raises
    ------
    AttributeError
        If ``name`` starts with an underscore; private names are never
        delegated, which also avoids infinite recursion.
    """
    if name.startswith("_"):
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )
    # Delegate attribute access to sets; map results by set name
    per_set: dict[str, Any] = {}
    for member in self._agentsets:
        per_set[cast(str, member.name)] = getattr(member, name)
    return per_set
[docs]
def __iter__(self) -> Iterator[AgentSet]:
    """Iterate over the contained AgentSets in registry order."""
    members = self._agentsets
    return iter(members)
[docs]
def __len__(self) -> int:
    """Return the number of AgentSets in the registry."""
    members = self._agentsets
    return len(members)
[docs]
def __repr__(self) -> str:
    """Return the ``repr`` of every AgentSet, one per line."""
    return "\n".join(map(repr, self._agentsets))
[docs]
def __reversed__(self) -> Iterator[AgentSet]:
    """Iterate over the contained AgentSets in reverse registry order."""
    members = self._agentsets
    return reversed(members)
[docs]
def __str__(self) -> str:
    """Return the ``str`` of every AgentSet, one per line."""
    return "\n".join(map(str, self._agentsets))
@property
def ids(self) -> pl.Series:
    """Return the cached ``unique_id`` values of all contained AgentSets."""
    cached_ids = self._ids
    return cached_ids
# Typing-only overloads: an int or str key yields a single AgentSet,
# a type key yields a list of AgentSets.
@overload
def __getitem__(self, key: int) -> AgentSet: ...
@overload
def __getitem__(self, key: str) -> AgentSet: ...
@overload
def __getitem__(self, key: type[AgentSet]) -> list[AgentSet]: ...
[docs]
def __getitem__(self, key: int | str | type[AgentSet]) -> AgentSet | list[AgentSet]:
    """Retrieve AgentSet(s) by index, name, or type.

    Raises
    ------
    IndexError
        If an integer key is out of range.
    KeyError
        If no AgentSet has the given name.
    TypeError
        If the key is not an int, a str, or an AgentSet subclass.
    """
    members = self._agentsets
    if isinstance(key, int):
        return members[key]
    if isinstance(key, str):
        named = [member for member in members if member.name == key]
        if named:
            return named[0]
        raise KeyError(f"Agent set '{key}' not found")
    if isinstance(key, type) and issubclass(key, AgentSet):
        return [member for member in members if isinstance(member, key)]
    raise TypeError("Key must be int, str (name), or AgentSet type")