Skip to content

Commit

Permalink
Randomness Manager Typing (#523)
Browse files Browse the repository at this point in the history
* randomness manager

* get around circular import

* make it attr/ prop double and single underscore

* finish setting clock and key_mapping

* whoops

* change to _clock_

* Update CHANGELOG.rst

* Update event.py
  • Loading branch information
patricktnast authored Oct 31, 2024
1 parent 9baf3f1 commit f376f2b
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 21 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
**3.0.16 - 10/30/24**
**3.0.16 - 10/31/24**

- Bugfix to prevent a LookupTable from changing order of the value columns
- Fix mypy errors in vivarium/framework/lookup/table.py
- Fix mypy errors in vivarium/framework/randomness/manager.py
- Fix mypy errors in vivarium/interface/utilities.py
- Typing changes in vivarium/framework/lookup/interpolation.py
- Fix broken build from LayeredConfigTree typing
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ exclude = [
'src/vivarium/framework/lookup/manager.py',
'src/vivarium/framework/population/manager.py',
'src/vivarium/framework/population/population_view.py',
'src/vivarium/framework/randomness/manager.py',
'src/vivarium/framework/results/context.py',
'src/vivarium/framework/results/interface.py',
'src/vivarium/framework/results/manager.py',
Expand Down
7 changes: 4 additions & 3 deletions src/vivarium/framework/randomness/index_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pandas.api.types as pdt

from vivarium.framework.randomness.exceptions import RandomnessError
from vivarium.types import ClockTime


class IndexMap:
Expand All @@ -35,7 +36,7 @@ def __init__(self, key_columns: list[str] | None = None, size: int = 1_000_000):
"""The mapping between the key columns and the randomness index."""
self._size = size

def update(self, new_keys: pd.DataFrame, clock_time: pd.Timestamp) -> None:
def update(self, new_keys: pd.DataFrame, clock_time: ClockTime) -> None:
"""Adds the new keys to the mapping.
Parameters
Expand Down Expand Up @@ -95,7 +96,7 @@ def _parse_new_keys(self, new_keys: pd.DataFrame) -> tuple[pd.Index[Any], pd.Ind
return new_mapping_index, final_mapping_index

def _build_final_mapping(
self, new_mapping_index: pd.Index[Any], clock_time: pd.Timestamp
self, new_mapping_index: pd.Index[Any], clock_time: ClockTime
) -> pd.Series[int]:
"""Builds a new mapping between key columns and the randomness index from the
new mapping index and the existing map.
Expand Down Expand Up @@ -154,7 +155,7 @@ def _resolve_collisions(
salt += 1
return current_mapping

def _hash(self, keys: pd.Index[Any], salt: int | pd.Timestamp = 0) -> pd.Series[int]:
def _hash(self, keys: pd.Index[Any], salt: ClockTime = 0) -> pd.Series[int]:
"""Hashes the index into an integer index in the range [0, self.stride]
Parameters
Expand Down
45 changes: 33 additions & 12 deletions src/vivarium/framework/randomness/manager.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# mypy: ignore-errors
"""
=========================
Randomness System Manager
=========================
"""
from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING

import pandas as pd

from vivarium.framework.randomness.exceptions import RandomnessError
from vivarium.framework.randomness.index_map import IndexMap
from vivarium.framework.randomness.stream import RandomnessStream, get_hash
from vivarium.manager import Interface, Manager
from vivarium.types import ClockTime

if TYPE_CHECKING:
from vivarium.framework.engine import Builder


class RandomnessManager(Manager):
Expand All @@ -26,28 +33,42 @@ class RandomnessManager(Manager):
}
}

def __init__(self):
self._seed = None
self._clock = None
self._key_columns = None
self._key_mapping = None
self._decision_points = dict()
def __init__(self) -> None:
self._seed: str = ""
self._clock_: Callable[[], ClockTime] | None = None
self._key_columns: list[str] = []
self._key_mapping_: IndexMap | None = None
self._decision_points: dict[str, RandomnessStream] = dict()

@property
def name(self):
def name(self) -> str:
return "randomness_manager"

def setup(self, builder):
@property
def _clock(self) -> Callable[[], ClockTime]:
if self._clock_ is None:
raise RandomnessError("RandomnessManager clock was invoked before being set.")
return self._clock_

@property
def _key_mapping(self) -> IndexMap:
if self._key_mapping_ is None:
raise RandomnessError(
"RandomnessManager key_mapping was invoked before being set."
)
return self._key_mapping_

def setup(self, builder: Builder) -> None:
self._seed = str(builder.configuration.randomness.random_seed)
if builder.configuration.randomness.additional_seed is not None:
self._seed += str(builder.configuration.randomness.additional_seed)
self._clock = builder.time.clock()
self._clock_ = builder.time.clock()
self._key_columns = builder.configuration.randomness.key_columns

map_size = builder.configuration.randomness.map_size
pop_size = builder.configuration.population.population_size
map_size = max(map_size, 10 * pop_size)
self._key_mapping = IndexMap(self._key_columns, map_size)
self._key_mapping_ = IndexMap(self._key_columns, map_size)

self.resources = builder.resources
self._add_constraint = builder.lifecycle.add_constraint
Expand Down Expand Up @@ -175,7 +196,7 @@ def register_simulants(self, simulants: pd.DataFrame) -> None:
)
self._key_mapping.update(simulants.loc[:, self._key_columns], self._clock())

def __str__(self):
def __str__(self) -> str:
return "RandomnessManager()"

def __repr__(self) -> str:
Expand Down
10 changes: 6 additions & 4 deletions tests/framework/randomness/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def test_randomness_manager_get_randomness_stream():
rm = RandomnessManager()
rm._add_constraint = lambda f, **kwargs: f
rm._seed = seed
rm._clock = mock_clock
rm._clock_ = mock_clock
rm._key_columns = ["age", "sex"]
rm._key_mapping_ = IndexMap(["age", "sex"])
stream = rm._get_randomness_stream("test")

assert stream.key == "test"
Expand All @@ -33,9 +35,9 @@ def test_randomness_manager_register_simulants():
rm = RandomnessManager()
rm._add_constraint = lambda f, **kwargs: f
rm._seed = seed
rm._clock = mock_clock
rm._clock_ = mock_clock
rm._key_columns = ["age", "sex"]
rm._key_mapping = IndexMap(["age", "sex"])
rm._key_mapping_ = IndexMap(["age", "sex"])

bad_df = pd.DataFrame({"age": range(10), "not_sex": [1] * 5 + [2] * 5})
with pytest.raises(RandomnessError):
Expand All @@ -56,6 +58,6 @@ def test_get_random_seed():
rm = RandomnessManager()
rm._add_constraint = lambda f, **kwargs: f
rm._seed = seed
rm._clock = mock_clock
rm._clock_ = mock_clock

assert rm.get_seed(decision_point) == get_hash(f"{decision_point}_{rm._clock()}_{seed}")

0 comments on commit f376f2b

Please sign in to comment.