From ddf0adad7f706d823ba654910907af8dca97e635 Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Thu, 7 Nov 2024 22:50:57 +0100 Subject: [PATCH] Make step method state keep track of var_names --- pymc/step_methods/compound.py | 2 ++ pymc/step_methods/state.py | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index d0393afd570..1fcb3d2673f 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -22,6 +22,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence +from dataclasses import field from enum import IntEnum, unique from typing import Any @@ -96,6 +97,7 @@ def infer_warn_stats_info( @dataclass_state class StepMethodState(DataClassState): + var_names: list[str] = field(metadata={"tensor_name": True, "frozen": True}) rng: RandomGeneratorState diff --git a/pymc/step_methods/state.py b/pymc/step_methods/state.py index f50aff7e25a..ec7bbbae483 100644 --- a/pymc/step_methods/state.py +++ b/pymc/step_methods/state.py @@ -67,7 +67,12 @@ def sampling_state(self) -> DataClassState: state_class = self._state_class kwargs = {} for field in fields(state_class): - val = getattr(self, field.name, field.default) + is_tensor_name = field.metadata.get("tensor_name", False) + val: Any + if is_tensor_name: + val = [var.name for var in getattr(self, "vars")] + else: + val = getattr(self, field.name, field.default) if val is MISSING: raise AttributeError( f"{type(self).__name__!r} object has no attribute {field.name!r}" @@ -89,11 +94,17 @@ def sampling_state(self, state: DataClassState): state, state_class ), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'" for field in fields(state_class): + is_tensor_name = field.metadata.get("tensor_name", False) state_val = deepcopy(getattr(state, field.name)) if isinstance(state_val, RandomGeneratorState): state_val = random_generator_from_state(state_val) - self_val = getattr(self, field.name) is_frozen = field.metadata.get("frozen", False) + self_val: Any + if is_tensor_name: + self_val = [var.name for var in getattr(self, "vars")] + assert is_frozen + else: + self_val = getattr(self, field.name, field.default) if is_frozen: if not equal_dataclass_values(state_val, self_val): raise ValueError(