Skip to content

Commit

Permalink
Remove _skip_checks argument from Model class
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 5, 2024
1 parent 316810c commit c5f2a28
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 33 deletions.
8 changes: 2 additions & 6 deletions src/lcm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses as dc
from collections.abc import Callable
from dataclasses import KW_ONLY, InitVar, dataclass, field
from dataclasses import KW_ONLY, dataclass, field

from lcm.exceptions import ModelInitilizationError, format_messages
from lcm.grids import Grid
Expand All @@ -29,12 +29,8 @@ class Model:
functions: dict[str, Callable] = field(default_factory=dict)
choices: dict[str, Grid] = field(default_factory=dict)
states: dict[str, Grid] = field(default_factory=dict)
_skip_checks: InitVar[bool] = False

def __post_init__(self, _skip_checks: bool) -> None:
if _skip_checks:
return

def __post_init__(self) -> None:
if type_errors := _validate_attribute_types(self):
msg = format_messages(type_errors)
raise ModelInitilizationError(msg)
Expand Down
37 changes: 23 additions & 14 deletions tests/test_create_params.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from dataclasses import dataclass
from typing import Any

import numpy as np
import pandas as pd
import pytest
Expand All @@ -8,11 +11,25 @@
_create_stochastic_transition_params,
create_params_template,
)
from lcm.model import Model


@dataclass
class ModelMock:
"""A model mock for testing the params creation functions.
This dataclass has the same attributes as the Model dataclass, but does not perform
any checks, which helps us to test the params creation functions in isolation.
"""

n_periods: int | None = None
functions: dict[str, Any] | None = None
choices: dict[str, Any] | None = None
states: dict[str, Any] | None = None


def test_create_params_without_shocks():
model = Model(
model = ModelMock(
functions={
"f": lambda a, b, c: None, # noqa: ARG005
},
Expand All @@ -22,7 +39,6 @@ def test_create_params_without_shocks():
states={
"b": None,
},
_skip_checks=True,
n_periods=None,
)
got = create_params_template(
Expand All @@ -34,7 +50,7 @@ def test_create_params_without_shocks():


def test_create_function_params():
model = Model(
model = ModelMock(
functions={
"f": lambda a, b, c: None, # noqa: ARG005
},
Expand All @@ -44,8 +60,6 @@ def test_create_function_params():
states={
"b": None,
},
_skip_checks=True,
n_periods=None,
)
got = _create_function_params(model)
assert got == {"f": {"c": np.nan}}
Expand All @@ -60,10 +74,9 @@ def next_a(a, _period):
index=["a"],
)

model = Model(
model = ModelMock(
n_periods=3,
functions={"next_a": next_a},
_skip_checks=True,
)

got = _create_stochastic_transition_params(
Expand All @@ -83,10 +96,8 @@ def next_a(a):
index=["a"],
)

model = Model(
model = ModelMock(
functions={"next_a": next_a},
_skip_checks=True,
n_periods=None,
)

with pytest.raises(ValueError, match="The following variables are stochastic, but"):
Expand All @@ -110,10 +121,8 @@ def next_a(a, b, _period):
index=["a", "b"],
)

model = Model(
model = ModelMock(
functions={"next_a": next_a},
_skip_checks=True,
n_periods=None,
)

with pytest.raises(ValueError, match="Stochastic transition functions can only"):
Expand Down
10 changes: 0 additions & 10 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,3 @@ def test_model_overlapping_states_choices():
choices={"health": DiscreteGrid([0, 1])},
functions={"utility": lambda: 0},
)


def test_model_skip_checks():
Model(
n_periods=-1, # invalid number of periods
states={},
choices={},
functions={"utility": lambda: 0},
_skip_checks=True,
)
23 changes: 20 additions & 3 deletions tests/test_process_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from dataclasses import dataclass
from typing import Any

import jax.numpy as jnp
import numpy as np
import pandas as pd
Expand All @@ -6,7 +9,7 @@
from pandas.testing import assert_frame_equal

import lcm.grid_helpers as grids_module
from lcm import DiscreteGrid, LinspaceGrid, Model
from lcm import DiscreteGrid, LinspaceGrid
from lcm.mark import StochasticInfo
from lcm.process_model import (
_get_function_info,
Expand All @@ -21,12 +24,27 @@
)


@dataclass
class ModelMock:
"""A model mock for testing the process_model function.
This dataclass has the same attributes as the Model dataclass, but does not perform
any checks, which helps us to test the process_model function in isolation.
"""

n_periods: int
functions: dict[str, Any]
choices: dict[str, Any]
states: dict[str, Any]


@pytest.fixture
def user_model():
def next_c(a, b):
return a + b

return Model(
return ModelMock(
n_periods=2,
functions={
"next_c": next_c,
Expand All @@ -37,7 +55,6 @@ def next_c(a, b):
states={
"c": DiscreteGrid([0, 1]),
},
_skip_checks=True,
)


Expand Down

0 comments on commit c5f2a28

Please sign in to comment.