diff --git a/src/lcm/model.py b/src/lcm/model.py index 24e2b9c..5ee7456 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -2,7 +2,7 @@ import dataclasses as dc from collections.abc import Callable -from dataclasses import KW_ONLY, InitVar, dataclass, field +from dataclasses import KW_ONLY, dataclass, field from lcm.exceptions import ModelInitilizationError, format_messages from lcm.grids import Grid @@ -29,12 +29,8 @@ class Model: functions: dict[str, Callable] = field(default_factory=dict) choices: dict[str, Grid] = field(default_factory=dict) states: dict[str, Grid] = field(default_factory=dict) - _skip_checks: InitVar[bool] = False - - def __post_init__(self, _skip_checks: bool) -> None: - if _skip_checks: - return + def __post_init__(self) -> None: if type_errors := _validate_attribute_types(self): msg = format_messages(type_errors) raise ModelInitilizationError(msg) diff --git a/tests/test_create_params.py b/tests/test_create_params.py index 1b24c01..801d525 100644 --- a/tests/test_create_params.py +++ b/tests/test_create_params.py @@ -1,3 +1,6 @@ +from dataclasses import dataclass +from typing import Any + import numpy as np import pandas as pd import pytest @@ -8,11 +11,25 @@ _create_stochastic_transition_params, create_params_template, ) -from lcm.model import Model + + +@dataclass +class ModelMock: + """A model mock for testing the params creation functions. + + This dataclass has the same attributes as the Model dataclass, but does not perform + any checks, which helps us to test the params creation functions in isolation. + + """ + + n_periods: int | None = None + functions: dict[str, Any] | None = None + choices: dict[str, Any] | None = None + states: dict[str, Any] | None = None def test_create_params_without_shocks(): - model = Model( + model = ModelMock( functions={ "f": lambda a, b, c: None, # noqa: ARG005 }, @@ -22,7 +39,6 @@ def test_create_params_without_shocks(): states={ "b": None, }, - _skip_checks=True, n_periods=None, ) got = create_params_template( @@ -34,7 +50,7 @@ def test_create_params_without_shocks(): def test_create_function_params(): - model = Model( + model = ModelMock( functions={ "f": lambda a, b, c: None, # noqa: ARG005 }, @@ -44,8 +60,6 @@ def test_create_function_params(): states={ "b": None, }, - _skip_checks=True, - n_periods=None, ) got = _create_function_params(model) assert got == {"f": {"c": np.nan}} @@ -60,10 +74,9 @@ def next_a(a, _period): index=["a"], ) - model = Model( + model = ModelMock( n_periods=3, functions={"next_a": next_a}, - _skip_checks=True, ) got = _create_stochastic_transition_params( @@ -83,10 +96,8 @@ def next_a(a): index=["a"], ) - model = Model( + model = ModelMock( functions={"next_a": next_a}, - _skip_checks=True, - n_periods=None, ) with pytest.raises(ValueError, match="The following variables are stochastic, but"): @@ -110,10 +121,8 @@ def next_a(a, b, _period): index=["a", "b"], ) - model = Model( + model = ModelMock( functions={"next_a": next_a}, - _skip_checks=True, - n_periods=None, ) with pytest.raises(ValueError, match="Stochastic transition functions can only"): diff --git a/tests/test_model.py b/tests/test_model.py index eb747fe..b0851c5 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -132,13 +132,3 @@ def test_model_overlapping_states_choices(): choices={"health": DiscreteGrid([0, 1])}, functions={"utility": lambda: 0}, ) - - -def test_model_skip_checks(): - Model( - n_periods=-1, # invalid number of periods - states={}, - choices={}, - functions={"utility": lambda: 0}, - _skip_checks=True, - ) diff --git a/tests/test_process_model.py b/tests/test_process_model.py index fac765c..552dc87 100644 --- a/tests/test_process_model.py +++ b/tests/test_process_model.py @@ -1,3 +1,6 @@ +from dataclasses import dataclass +from typing import Any + import jax.numpy as jnp import numpy as np import pandas as pd @@ -6,7 +9,7 @@ from pandas.testing import assert_frame_equal import lcm.grid_helpers as grids_module -from lcm import DiscreteGrid, LinspaceGrid, Model +from lcm import DiscreteGrid, LinspaceGrid from lcm.mark import StochasticInfo from lcm.process_model import ( _get_function_info, @@ -21,12 +24,27 @@ ) +@dataclass +class ModelMock: + """A model mock for testing the process_model function. + + This dataclass has the same attributes as the Model dataclass, but does not perform + any checks, which helps us to test the process_model function in isolation. + + """ + + n_periods: int + functions: dict[str, Any] + choices: dict[str, Any] + states: dict[str, Any] + + @pytest.fixture def user_model(): def next_c(a, b): return a + b - return Model( + return ModelMock( n_periods=2, functions={ "next_c": next_c, @@ -37,7 +55,6 @@ def next_c(a, b): states={ "c": DiscreteGrid([0, 1]), }, - _skip_checks=True, )