From 4d30bf4759df0d5ea3306317c00ca4bd8569f2fe Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 19 Sep 2024 15:58:42 +0200 Subject: [PATCH] Solve #87 --- examples/long_running.py | 13 +- explanations/dispatchers.ipynb | 12 +- explanations/function_representation.ipynb | 10 +- pyproject.toml | 4 +- src/lcm/entry_point.py | 2 +- src/lcm/grids.py | 128 +++++++++++------- .../discrete_state_conversion.py | 64 ++++----- src/lcm/input_processing/process_model.py | 4 +- src/lcm/input_processing/util.py | 2 +- tests/conftest.py | 18 +++ .../test_create_params_template.py | 6 +- .../test_discrete_state_conversion.py | 34 +++-- tests/input_processing/test_process_model.py | 28 ++-- tests/test_grids.py | 81 +++++++---- tests/test_model.py | 12 +- tests/test_models/deterministic.py | 32 +++-- tests/test_models/stochastic.py | 28 +++- tests/test_solution_on_toy_model.py | 26 +++- 18 files changed, 335 insertions(+), 169 deletions(-) diff --git a/examples/long_running.py b/examples/long_running.py index 9c54a5f0..3491785e 100644 --- a/examples/long_running.py +++ b/examples/long_running.py @@ -1,5 +1,7 @@ """Example specification for a consumption-savings model with health and exercise.""" +from dataclasses import dataclass + import jax.numpy as jnp from lcm import DiscreteGrid, LinspaceGrid, Model @@ -9,6 +11,15 @@ # ====================================================================================== +# -------------------------------------------------------------------------------------- +# Categorical variables +# -------------------------------------------------------------------------------------- +@dataclass +class WorkingState: + retired: int = 0 + working: int = 1 + + # -------------------------------------------------------------------------------------- # Utility function # -------------------------------------------------------------------------------------- @@ -67,7 +78,7 @@ def consumption_constraint(consumption, wealth, labor_income): "age": age, }, choices={ - "working": DiscreteGrid([0, 1]), + "working": DiscreteGrid(WorkingState), "consumption": LinspaceGrid( start=1, stop=100, diff --git a/explanations/dispatchers.ipynb b/explanations/dispatchers.ipynb index 81aae290..cab40c16 100644 --- a/explanations/dispatchers.ipynb +++ b/explanations/dispatchers.ipynb @@ -16,6 +16,8 @@ "metadata": {}, "outputs": [], "source": [ + "from dataclasses import dataclass\n", + "\n", "import jax.numpy as jnp\n", "import pytest\n", "from jax import vmap\n", @@ -277,6 +279,12 @@ "from lcm import DiscreteGrid, LinspaceGrid, Model\n", "\n", "\n", + "@dataclass\n", + "class RetirementStatus:\n", + " working: int = 0\n", + " retired: int = 1\n", + "\n", + "\n", "def utility(consumption, retirement, lagged_retirement, wealth):\n", " working = 1 - retirement\n", " retirement_habit = lagged_retirement * wealth\n", @@ -296,11 +304,11 @@ " },\n", " n_periods=1,\n", " choices={\n", - " \"retirement\": DiscreteGrid([0, 1]),\n", + " \"retirement\": DiscreteGrid(RetirementStatus),\n", " \"consumption\": LinspaceGrid(start=1, stop=2, n_points=2),\n", " },\n", " states={\n", - " \"lagged_retirement\": DiscreteGrid([0, 1]),\n", + " \"lagged_retirement\": DiscreteGrid(RetirementStatus),\n", " \"wealth\": LinspaceGrid(start=1, stop=4, n_points=4),\n", " },\n", ")" diff --git a/explanations/function_representation.ipynb b/explanations/function_representation.ipynb index 0e89ddb5..82c4429b 100644 --- a/explanations/function_representation.ipynb +++ b/explanations/function_representation.ipynb @@ -76,11 +76,19 @@ "metadata": {}, "outputs": [], "source": [ + "from dataclasses import dataclass\n", + "\n", "import jax.numpy as jnp\n", "\n", "from lcm import DiscreteGrid, LinspaceGrid, Model\n", "\n", "\n", + "@dataclass\n", + "class RetirementStatus:\n", + " working: int = 0\n", + " retired: int = 1\n", + "\n", + "\n", "def utility(consumption, working, disutility_of_work):\n", " return jnp.log(consumption) - disutility_of_work * working\n", "\n", @@ -125,7 +133,7 @@ " \"age\": age,\n", " },\n", " choices={\n", - " \"retirement\": DiscreteGrid([0, 1]),\n", + " \"retirement\": DiscreteGrid(RetirementStatus),\n", " \"consumption\": LinspaceGrid(start=1, stop=400, n_points=20),\n", " },\n", " states={\n", diff --git a/pyproject.toml b/pyproject.toml index e84f401d..085dccff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -228,10 +228,10 @@ extend-ignore = [ [tool.ruff.lint.per-file-ignores] "docs/source/conf.py" = ["E501", "ERA001", "DTZ005"] "tests/*" = ["PLR2004", "D101"] -"examples/*" = ["INP001"] +"examples/*" = ["INP001", "D101"] "explanations/*" = ["INP001", "B018", "T201", "E402", "PD008"] "scripts/*" = ["INP001", "D101", "RET503"] -"**/*.ipynb" = ["FBT003", "E402"] +"**/*.ipynb" = ["FBT003", "E402", "D101"] [tool.ruff.lint.pydocstyle] convention = "google" diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 19cd5209..19d9a2c1 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -195,7 +195,7 @@ def get_lcm_function( elif targets == "solve_and_simulate": _target = partial(simulate_model, solve_model=solve_model) - user_params = _mod.converter.params_from_internal(_mod.params) + user_params = _mod.converter.internal_to_params(_mod.params) return cast(Callable, _target), user_params diff --git a/src/lcm/grids.py b/src/lcm/grids.py index e390c2e4..ca23c004 100644 --- a/src/lcm/grids.py +++ b/src/lcm/grids.py @@ -16,63 +16,65 @@ class Grid(ABC): """LCM Grid base class.""" @abstractmethod - def to_jax(self) -> jnp.ndarray: + def to_jax(self) -> Array: """Convert the grid to a Jax array.""" -@dataclass(frozen=True) class DiscreteGrid(Grid): - """A grid of discrete values. + """A class representing a discrete grid. + + Args: + category_class (type): The category class representing the grid categories. Must + be a dataclass with fields that have unique scalar int or float values. Attributes: - options: The options in the grid. Must be an iterable of scalar int or float - values. + categories: The list of category names. + codes: The list of category codes. + + Raises: + GridInitializationError: If the `category_class` is not a dataclass with scalar + int or float fields. """ - options: type + def __init__(self, category_class: type) -> None: + """Initialize the DiscreteGrid. - def __post_init__(self) -> None: - if not is_dataclass(self.options): + Args: + category_class (type): The category class representing the grid categories. + Must be a dataclass with fields that have unique scalar int or float + values. + + """ + if not is_dataclass(category_class): raise GridInitializationError( - "options must be a dataclass with scalar int or float fields, but is " - f"{self.options}." + "category_class must be a dataclass with scalar int or float fields, " + f"but is {category_class}." ) - errors = _validate_discrete_grid(self.options) + names_and_values = _get_field_names_and_values(category_class) + + errors = _validate_discrete_grid(names_and_values) if errors: msg = format_messages(errors) raise GridInitializationError(msg) - def to_jax(self) -> Array: - """Convert the grid to a Jax array.""" - return jnp.array(_get_fields(self.options)) - - -def _get_fields(dc: type) -> list[Any]: - """Get the fields of a dataclass. - - Args: - dc: The dataclass to get the fields of. + self.__categories = list(names_and_values.keys()) + self.__codes = list(names_and_values.values()) - Returns: - list[Any]: The fields of the dataclass. + @property + def categories(self) -> list[str]: + """Get the list of category names.""" + return self.__categories - Raises: - GridInitializationError: If the fields of the dataclass do not have default - values, or the instantiated dataclass does not have all fields. None values - are treated as if they do not exist. + @property + def codes(self) -> list[int | float]: + """Get the list of category codes.""" + return self.__codes - """ - _fields = {field.name: getattr(dc, field.name, None) for field in fields(dc)} - fields_without_defaults = [name for name, value in _fields.items() if value is None] - if fields_without_defaults: - raise GridInitializationError( - f"To use a DiscreteGrid, all fields of the options dataclass must have " - f"default values. The following fields do not have default values: " - f"{fields_without_defaults}." - ) - return list(_fields.values()) + def to_jax(self) -> Array: + """Convert the grid to a Jax array.""" + return jnp.array(self.codes) @dataclass(frozen=True, kw_only=True) @@ -157,32 +159,62 @@ def get_coordinate(self, value: Scalar) -> Scalar: # ====================================================================================== -def _validate_discrete_grid(options: type) -> list[str]: - """Validate the discrete grid options. +def _validate_discrete_grid(names_and_values: dict[str, Any]) -> list[str]: + """Validate the field names and values of the category_class passed to DiscreteGrid. Args: - options: The user options to validate in form of a dataclass. + names_and_values: A dictionary with the field names as keys and the field + values as values. Returns: list[str]: A list of error messages. """ - values = _get_fields(options) - error_messages = [] - if not len(values) > 0: - error_messages.append("options must contain at least one element") + if not len(names_and_values) > 0: + error_messages.append( + "category_class passed to DiscreteGrid must have at least one field" + ) - if not all(isinstance(value, int | float) for value in values): - error_messages.append("options must contain only scalar int or float values") + names_with_non_numerical_values = [ + name + for name, value in names_and_values.items() + if not isinstance(value, int | float) + ] + if names_with_non_numerical_values: + error_messages.append( + "Field values of the category_class passed to DiscreteGrid can only be " + "scalar int or float values. The values to the following fields are not: " + f"{names_with_non_numerical_values}" + ) - if len(values) != len(set(values)): - error_messages.append("options must contain unique values") + values = list(names_and_values.values()) + duplicated_values = [v for v in values if values.count(v) > 1] + if duplicated_values: + error_messages.append( + "Field values of the category_class passed to DiscreteGrid must be unique. " + "The following values are duplicated: " + f"{set(duplicated_values)}" + ) return error_messages +def _get_field_names_and_values(dc: type) -> dict[str, Any]: + """Get the fields of a dataclass. + + Args: + dc: The dataclass to get the fields of. + + Returns: + A dictionary with the field names as keys and the field values as values. If + no value is provided for a field, the value is set to None. + + """ + return {field.name: getattr(dc, field.name, None) for field in fields(dc)} + + def _validate_continuous_grid( start: float, stop: float, diff --git a/src/lcm/input_processing/discrete_state_conversion.py b/src/lcm/input_processing/discrete_state_conversion.py index 5209b4c1..2dd3991e 100644 --- a/src/lcm/input_processing/discrete_state_conversion.py +++ b/src/lcm/input_processing/discrete_state_conversion.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from dataclasses import dataclass, field +from dataclasses import dataclass, field, make_dataclass import jax.numpy as jnp from dags.signature import with_signature @@ -35,7 +35,7 @@ class DiscreteStateConverter: index_to_label: dict[str, Callable[[Array], Array]] = field(default_factory=dict) label_to_index: dict[str, Callable[[Array], Array]] = field(default_factory=dict) - def params_from_internal(self, params: ParamsDict) -> ParamsDict: + def internal_to_params(self, params: ParamsDict) -> ParamsDict: """Convert parameters from internal to external representation. If a state has been converted, the name of its corresponding next function must @@ -61,7 +61,7 @@ def params_to_internal(self, params: ParamsDict) -> ParamsDict: out[f"next___{var}_index__"] = params[f"next_{var}"] return out - def states_from_internal(self, states: dict[str, Array]) -> dict[str, Array]: + def internal_to_states(self, states: dict[str, Array]) -> dict[str, Array]: """Convert states from internal to external representation. If a state has been converted, the name of its corresponding index function must @@ -90,30 +90,30 @@ def states_to_internal(self, states: dict[str, Array]) -> dict[str, Array]: return out -def convert_discrete_options_to_indices( +def convert_discrete_codes_to_indices( model: Model, ) -> tuple[Model, DiscreteStateConverter]: - """Update the user model to ensure that discrete variables have index options. + """Update the user model to ensure that discrete variables have index codes. - For each discrete variable with non-index options, we: + For each discrete variable with non-index codes, we: 1. Remove the variable from the states or choices dictionary - 2. Replace it with a new state or choice with index options (__{var}_index__) + 2. Replace it with a new state or choice with index codes (__{var}_index__) 3. Add updated next functions (if the variable was a state variable) - 4. Add a function that maps the index options to the original options + 4. Add a function that maps the index codes to the original codes Args: model: The model as provided by the user. Returns: - - The model with all discrete variables having index options. + - The model with all discrete variables having index codes. - A converter that can be used to convert between the internal and external representation of the model. """ gridspecs = get_gridspecs(model) - non_index_discrete_vars = _get_discrete_vars_with_non_index_options(model) + non_index_discrete_vars = _get_discrete_vars_with_non_index_codes(model) # fast path if not non_index_discrete_vars: @@ -127,7 +127,11 @@ def convert_discrete_options_to_indices( # ---------------------------------------------------------------------------------- for var in non_index_discrete_vars: grid: DiscreteGrid = gridspecs[var] # type: ignore[assignment] - index_grid = DiscreteGrid(options=list(range(len(grid.options)))) + index_category_class = make_dataclass( + grid.__str__(), + [(f"__{name}_index__", int, i) for i, name in enumerate(grid.categories)], + ) + index_grid = DiscreteGrid(index_category_class) if var in model.states: states.pop(var) @@ -146,7 +150,7 @@ def convert_discrete_options_to_indices( # Add index to label functions # ---------------------------------------------------------------------------------- index_to_label_funcs = { - var: _get_index_to_label_func(gridspecs[var].to_jax(), name=var) + var: _get_index_to_code_func(gridspecs[var].to_jax(), name=var) for var in non_index_discrete_vars } functions = functions | index_to_label_funcs @@ -156,7 +160,7 @@ def convert_discrete_options_to_indices( converted_states = [s for s in non_index_discrete_vars if s in model.states] label_to_index_funcs_for_states = { - var: _get_label_to_index_func(gridspecs[var].to_jax(), name=var) + var: _get_code_to_index_func(gridspecs[var].to_jax(), name=var) for var in converted_states } @@ -179,34 +183,32 @@ def convert_discrete_options_to_indices( return new_model, converter -def _get_discrete_vars_with_non_index_options(model: Model) -> list[str]: - """Get discrete variables with non-index options. +def _get_discrete_vars_with_non_index_codes(model: Model) -> list[str]: + """Get discrete variables with non-index codes. - Collect all discrete variables with options that do not correspond to indices. + Collect all discrete variables with codes that do not correspond to indices. """ gridspecs = get_gridspecs(model) discrete_vars = [] for name, spec in gridspecs.items(): - if isinstance(spec, DiscreteGrid) and list(spec.options) != list( - range(len(spec.options)) + if isinstance(spec, DiscreteGrid) and list(spec.codes) != list( + range(len(spec.codes)) ): discrete_vars.append(name) return discrete_vars -def _get_index_to_label_func( - labels_array: Array, name: str -) -> Callable[[Array], Array]: - """Get function mapping from index to label. +def _get_index_to_code_func(codes_array: Array, name: str) -> Callable[[Array], Array]: + """Get function mapping from index to code. Args: - labels_array: An array of labels. + codes_array: An array of codes. name: The name of resulting function argument. Returns: - A function mapping an array with indices corresponding to labels_array to the - corresponding labels. + A function mapping an array with indices corresponding to codes_array to the + corresponding codes. """ arg_name = f"__{name}_index__" @@ -215,22 +217,20 @@ def _get_index_to_label_func( def func(*args, **kwargs): kwargs = all_as_kwargs(args, kwargs, arg_names=[arg_name]) index = kwargs[arg_name] - return labels_array[index] + return codes_array[index] return func -def _get_label_to_index_func( - labels_array: Array, name: str -) -> Callable[[Array], Array]: +def _get_code_to_index_func(codes_array: Array, name: str) -> Callable[[Array], Array]: """Get function mapping from label to index. Args: - labels_array: An array of labels. + codes_array: An array of codes. name: The name of resulting function argument. Returns: - A function mapping an array with values in labels_array to their corresponding + A function mapping an array with values in codes_array to their corresponding indices. """ @@ -239,6 +239,6 @@ def _get_label_to_index_func( def label_to_index(*args, **kwargs): kwargs = all_as_kwargs(args, kwargs, arg_names=[name]) data = kwargs[name] - return jnp.argmax(data[:, None] == labels_array[None, :], axis=1) + return jnp.argmax(data[:, None] == codes_array[None, :], axis=1) return label_to_index diff --git a/src/lcm/input_processing/process_model.py b/src/lcm/input_processing/process_model.py index aac68b68..17452ac3 100644 --- a/src/lcm/input_processing/process_model.py +++ b/src/lcm/input_processing/process_model.py @@ -10,7 +10,7 @@ from lcm.functools import all_as_args, all_as_kwargs from lcm.input_processing.create_params_template import create_params_template from lcm.input_processing.discrete_state_conversion import ( - convert_discrete_options_to_indices, + convert_discrete_codes_to_indices, ) from lcm.input_processing.util import ( get_function_info, @@ -39,7 +39,7 @@ def process_model(model: Model) -> InternalModel: The processed model. """ - tmp_model, converter = convert_discrete_options_to_indices(model) + tmp_model, converter = convert_discrete_codes_to_indices(model) params = create_params_template(tmp_model) diff --git a/src/lcm/input_processing/util.py b/src/lcm/input_processing/util.py index 70a1a304..5392c8a7 100644 --- a/src/lcm/input_processing/util.py +++ b/src/lcm/input_processing/util.py @@ -133,7 +133,7 @@ def get_gridspecs( Returns: dict: Dictionary containing all variables of the model. The keys are the names of the variables. The values describe which values the variable - can take. For discrete variables these are the options. For continuous + can take. For discrete variables these are the codes. For continuous variables this is information about how to build the grids. """ diff --git a/tests/conftest.py b/tests/conftest.py index a7a9a731..695ef46d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,23 @@ +from dataclasses import make_dataclass + +import pytest from jax import config def pytest_sessionstart(session): # noqa: ARG001 config.update("jax_enable_x64", val=True) + + +def _category_class_factory(values: list[int]): + init = [(f"cat{i}", int, value) for i, value in enumerate(values)] + return make_dataclass("CategoryClass", init) + + +@pytest.fixture(scope="session") +def category_class_factory(): + return _category_class_factory + + +@pytest.fixture(scope="session") +def binary_category_class(category_class_factory): + return category_class_factory([0, 1]) diff --git a/tests/input_processing/test_create_params_template.py b/tests/input_processing/test_create_params_template.py index e0b9c0d5..475831fc 100644 --- a/tests/input_processing/test_create_params_template.py +++ b/tests/input_processing/test_create_params_template.py @@ -29,17 +29,17 @@ class ModelMock: states: dict[str, Any] | None = None -def test_create_params_without_shocks(): +def test_create_params_without_shocks(binary_category_class): model = ModelMock( functions={ "f": lambda a, b, c: None, # noqa: ARG005 "next_b": lambda b: b, }, choices={ - "a": DiscreteGrid([0, 1]), + "a": DiscreteGrid(binary_category_class), }, states={ - "b": DiscreteGrid([0, 1]), + "b": DiscreteGrid(binary_category_class), }, n_periods=None, ) diff --git a/tests/input_processing/test_discrete_state_conversion.py b/tests/input_processing/test_discrete_state_conversion.py index f34efe2b..5de2effd 100644 --- a/tests/input_processing/test_discrete_state_conversion.py +++ b/tests/input_processing/test_discrete_state_conversion.py @@ -7,9 +7,10 @@ from lcm import DiscreteGrid from lcm.input_processing.discrete_state_conversion import ( - _get_discrete_vars_with_non_index_options, - _get_index_to_label_func, - convert_discrete_options_to_indices, + _get_code_to_index_func, + _get_discrete_vars_with_non_index_codes, + _get_index_to_code_func, + convert_discrete_codes_to_indices, ) @@ -29,7 +30,7 @@ class ModelMock: @pytest.fixture -def model(): +def model(category_class_factory): def next_c(a, b): return a + b @@ -39,35 +40,42 @@ def next_c(a, b): "next_c": next_c, }, choices={ - "a": DiscreteGrid([0, 1]), + "a": DiscreteGrid(category_class_factory([0, 1])), }, states={ - "c": DiscreteGrid([1, 10]), + "c": DiscreteGrid(category_class_factory([1, 10])), }, ) def test_get_index_to_label_func(): - labels = jnp.array([1, 10]) - got = _get_index_to_label_func(labels_array=labels, name="foo") + codes_array = jnp.array([1, 10]) + got = _get_index_to_code_func(codes_array, name="foo") assert got(__foo_index__=0) == 1 assert got(1) == 10 -def test_get_discrete_vars_with_non_index_options(model): - got = _get_discrete_vars_with_non_index_options(model) +def test_get_code_to_index_func(): + codes_array = jnp.array([1, 10]) + got = _get_code_to_index_func(codes_array, name="foo") + assert_array_equal(got(foo=codes_array), jnp.arange(2)) + + +def test_get_discrete_vars_with_non_index_codes(model): + got = _get_discrete_vars_with_non_index_codes(model) assert got == ["c"] -def test_convert_discrete_options_to_indices(model): +def test_convert_discrete_codes_to_indices(model): # add replace method to model mock model.replace = lambda **kwargs: ModelMock(**kwargs, n_periods=model.n_periods) - got, _ = convert_discrete_options_to_indices(model) + got, _ = convert_discrete_codes_to_indices(model) assert "c" not in got.states assert "__c_index__" in got.states assert "c" in got.functions - assert_array_equal(got.states["__c_index__"], DiscreteGrid([0, 1])) + assert got.states["__c_index__"].categories == ["__cat0_index__", "__cat1_index__"] + assert got.states["__c_index__"].codes == [0, 1] assert got.functions["c"](0) == 1 assert got.functions["c"](1) == 10 diff --git a/tests/input_processing/test_process_model.py b/tests/input_processing/test_process_model.py index df6f2184..17b7c77c 100644 --- a/tests/input_processing/test_process_model.py +++ b/tests/input_processing/test_process_model.py @@ -39,7 +39,7 @@ class ModelMock: @pytest.fixture -def model(): +def model(category_class_factory): def next_c(a, b): return a + b @@ -49,10 +49,10 @@ def next_c(a, b): "next_c": next_c, }, choices={ - "a": DiscreteGrid([0, 1]), + "a": DiscreteGrid(category_class_factory([0, 1])), }, states={ - "c": DiscreteGrid([1, 10]), + "c": DiscreteGrid(category_class_factory([1, 10])), }, ) @@ -91,8 +91,13 @@ def test_get_variable_info(model): def test_get_gridspecs(model): got = get_gridspecs(model) - assert got["a"] == DiscreteGrid([0, 1]) - assert got["c"] == DiscreteGrid([1, 10]) + assert isinstance(got["a"], DiscreteGrid) + assert got["a"].categories == ["cat0", "cat1"] + assert got["a"].codes == [0, 1] + + assert isinstance(got["c"], DiscreteGrid) + assert got["c"].categories == ["cat0", "cat1"] + assert got["c"].codes == [1, 10] def test_get_grids(model): @@ -137,8 +142,13 @@ def test_process_model_iskhakov_et_al_2017(): ) assert model.gridspecs["consumption"] == consumption_grid - assert model.gridspecs["retirement"] == DiscreteGrid([0, 1]) - assert model.gridspecs["lagged_retirement"] == DiscreteGrid([0, 1]) + assert isinstance(model.gridspecs["retirement"], DiscreteGrid) + assert model.gridspecs["retirement"].categories == ["working", "retired"] + assert model.gridspecs["retirement"].codes == [0, 1] + + assert isinstance(model.gridspecs["lagged_retirement"], DiscreteGrid) + assert model.gridspecs["lagged_retirement"].categories == ["working", "retired"] + assert model.gridspecs["lagged_retirement"].codes == [0, 1] # Grids expected = grid_helpers.linspace(**model_config.choices["consumption"].__dict__) @@ -195,7 +205,9 @@ def test_process_model(): ) assert model.gridspecs["consumption"] == consumption_specs - assert model.gridspecs["retirement"] == DiscreteGrid([0, 1]) + assert isinstance(model.gridspecs["retirement"], DiscreteGrid) + assert model.gridspecs["retirement"].categories == ["working", "retired"] + assert model.gridspecs["retirement"].codes == [0, 1] # Grids expected = grid_helpers.linspace(**model_config.choices["consumption"].__dict__) diff --git a/tests/test_grids.py b/tests/test_grids.py index 61bf614c..facf5614 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -8,52 +8,71 @@ DiscreteGrid, LinspaceGrid, LogspaceGrid, - _get_fields, + _get_field_names_and_values, _validate_continuous_grid, _validate_discrete_grid, ) def test_validate_discrete_grid_empty(): - options = make_dataclass("Options", []) - assert _validate_discrete_grid(options) == [ - "options must contain at least one element" - ] + names_and_values = {} + expected = ["category_class passed to DiscreteGrid must have at least one field"] + assert _validate_discrete_grid(names_and_values) == expected def test_validate_discrete_grid_non_scalar_input(): - options = make_dataclass("Options", [("a", int, 1), ("b", str, "wrong_type")]) - assert _validate_discrete_grid(options) == [ - "options must contain only scalar int or float values", + names_and_values = {"a": 1, "b": "wrong_type"} + expected = [ + ( + "Field values of the category_class passed to DiscreteGrid can only be " + "scalar int or float values. The values to the following fields are not: " + "['b']" + ) ] + assert _validate_discrete_grid(names_and_values) == expected + + +def test_validate_discrete_grid_none_input(): + names_and_values = {"a": None, "b": 1} + expected = [ + ( + "Field values of the category_class passed to DiscreteGrid can only be " + "scalar int or float values. The values to the following fields are not: " + "['a']" + ) + ] + assert _validate_discrete_grid(names_and_values) == expected def test_validate_discrete_grid_non_unique(): - options = make_dataclass("Options", [("a", int, 1), ("b", int, 2), ("c", int, 2)]) - assert _validate_discrete_grid(options) == [ - "options must contain unique values", + names_and_values = {"a": 1, "b": 2, "c": 2} + expected = [ + ( + "Field values of the category_class passed to DiscreteGrid must be unique. " + "The following values are duplicated: {2}" + ) ] + assert _validate_discrete_grid(names_and_values) == expected def test_get_fields_with_defaults(): - options = make_dataclass("Options", [("a", int, 1), ("b", int, 2), ("c", int, 3)]) - assert _get_fields(options) == [1, 2, 3] + category_class = make_dataclass("Category", [("a", int, 1), ("b", int, 2)]) + assert _get_field_names_and_values(category_class) == {"a": 1, "b": 2} -def test_get_fields_instance(): - options = make_dataclass("Options", [("a", int), ("b", int)]) - assert _get_fields(options(a=1, b=2)) == [1, 2] +def test_get_fields_no_defaults(): + category_class = make_dataclass("Category", [("a", int), ("b", int)]) + assert _get_field_names_and_values(category_class) == {"a": None, "b": None} -def test_get_fields_empty(): - options = make_dataclass("Options", []) - assert _get_fields(options) == [] +def test_get_fields_instance(): + category_class = make_dataclass("Category", [("a", int), ("b", int)]) + assert _get_field_names_and_values(category_class(a=1, b=2)) == {"a": 1, "b": 2} -def test_get_fields_no_defaults(): - options = make_dataclass("Options", [("a", int), ("b", int)]) - with pytest.raises(GridInitializationError, match="To use a DiscreteGrid"): - _get_fields(options) +def test_get_fields_empty(): + category_class = make_dataclass("Category", []) + assert _get_field_names_and_values(category_class) == {} def test_validate_continuous_grid_invalid_start(): @@ -95,8 +114,10 @@ def test_logspace_grid_creation(): def test_discrete_grid_creation(): - options = make_dataclass("Options", [("a", int, 0), ("b", int, 1), ("c", int, 2)]) - grid = DiscreteGrid(options) + category_class = make_dataclass( + "Category", [("a", int, 0), ("b", int, 1), ("c", int, 2)] + ) + grid = DiscreteGrid(category_class) assert np.allclose(grid.to_jax(), np.arange(3)) @@ -110,10 +131,12 @@ def test_logspace_grid_invalid_start(): LogspaceGrid(start=1, stop=0, n_points=10) -def test_discrete_grid_invalid_options(): - options = make_dataclass("Options", [("a", int, 1), ("b", str, "wrong_type")]) +def test_discrete_grid_invalid_category_class(): + category_class = make_dataclass( + "Category", [("a", int, 1), ("b", str, "wrong_type")] + ) with pytest.raises( GridInitializationError, - match="options must contain only scalar int or float values", + match="Field values of the category_class passed to DiscreteGrid can only be", ): - DiscreteGrid(options) + DiscreteGrid(category_class) diff --git a/tests/test_model.py b/tests/test_model.py index 547a7762..4536946b 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -95,15 +95,15 @@ def test_model_invalid_n_periods(): ) -def test_model_missing_next_func(): +def test_model_missing_next_func(binary_category_class): with pytest.raises( ModelInitilizationError, match="Each state must have a corresponding next state function.", ): Model( n_periods=2, - states={"health": DiscreteGrid([0, 1])}, - choices={"exercise": DiscreteGrid([0, 1])}, + states={"health": DiscreteGrid(binary_category_class)}, + choices={"exercise": DiscreteGrid(binary_category_class)}, functions={"utility": lambda: 0}, ) @@ -121,14 +121,14 @@ def test_model_missing_utility(): ) -def test_model_overlapping_states_choices(): +def test_model_overlapping_states_choices(binary_category_class): with pytest.raises( ModelInitilizationError, match="States and choices cannot have overlapping names.", ): Model( n_periods=2, - states={"health": DiscreteGrid([0, 1])}, - choices={"health": DiscreteGrid([0, 1])}, + states={"health": DiscreteGrid(binary_category_class)}, + choices={"health": DiscreteGrid(binary_category_class)}, functions={"utility": lambda: 0}, ) diff --git a/tests/test_models/deterministic.py b/tests/test_models/deterministic.py index e99b0031..e53ac12c 100644 --- a/tests/test_models/deterministic.py +++ b/tests/test_models/deterministic.py @@ -8,7 +8,7 @@ """ from copy import deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, make_dataclass import jax.numpy as jnp @@ -20,12 +20,7 @@ # -------------------------------------------------------------------------------------- -# Labels -# -------------------------------------------------------------------------------------- -# Dataclasses can be used to represent labeled versions of discrete grids. For this, you -# need to define a dataclass with one field per label with the same value as used in the -# grid definition. They are especially useful to make case distinctions more readable. -# One example can be found in the absorbing_retirement_filter function below. +# Categorical variables # -------------------------------------------------------------------------------------- @dataclass class RetirementStatus: @@ -124,7 +119,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "working": working, }, choices={ - "retirement": DiscreteGrid([0, 1]), + "retirement": DiscreteGrid(RetirementStatus), "consumption": LinspaceGrid( start=1, stop=400, @@ -137,7 +132,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): stop=400, n_points=100, ), - "lagged_retirement": DiscreteGrid([0, 1]), + "lagged_retirement": DiscreteGrid(RetirementStatus), }, ) @@ -158,7 +153,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "age": age, }, choices={ - "retirement": DiscreteGrid([0, 1]), + "retirement": DiscreteGrid(RetirementStatus), "consumption": LinspaceGrid( start=1, stop=400, @@ -175,6 +170,17 @@ def absorbing_retirement_filter(retirement, lagged_retirement): ) +@dataclass +class DiscreteConsumptionStatus: + low: int = 1 + high: int = 2 + + +DiscreteWealthStatus = make_dataclass( + "DiscreteWealthStatus", [(f"level_{w}", int, w) for w in range(1, 401)] +) + + ISKHAKOV_ET_AL_2017_FULLY_DISCRETE = Model( description=( "Starts from Iskhakov et al. (2017), removes filters and the lagged_retirement " @@ -189,11 +195,11 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "working": working, }, choices={ - "retirement": DiscreteGrid([0, 1]), - "consumption": DiscreteGrid([1, 2]), + "retirement": DiscreteGrid(RetirementStatus), + "consumption": DiscreteGrid(DiscreteConsumptionStatus), }, states={ - "wealth": DiscreteGrid(list(range(1, 401))), + "wealth": DiscreteGrid(DiscreteWealthStatus), }, ) diff --git a/tests/test_models/stochastic.py b/tests/test_models/stochastic.py index e5e96005..9ff7b57c 100644 --- a/tests/test_models/stochastic.py +++ b/tests/test_models/stochastic.py @@ -10,6 +10,7 @@ """ from copy import deepcopy +from dataclasses import dataclass import jax.numpy as jnp @@ -21,6 +22,27 @@ # ====================================================================================== +# -------------------------------------------------------------------------------------- +# Categorical variables +# -------------------------------------------------------------------------------------- +@dataclass +class HealthStatus: + bad: int = 0 + good: int = 1 + + +@dataclass +class PartnerStatus: + single: int = 0 + partnered: int = 1 + + +@dataclass +class WorkingState: + retired: int = 0 + working: int = 1 + + # -------------------------------------------------------------------------------------- # Utility function # -------------------------------------------------------------------------------------- @@ -91,7 +113,7 @@ def consumption_constraint(consumption, wealth): "labor_income": labor_income, }, choices={ - "working": DiscreteGrid([0, 1]), + "working": DiscreteGrid(WorkingState), "consumption": LinspaceGrid( start=1, stop=100, @@ -99,8 +121,8 @@ def consumption_constraint(consumption, wealth): ), }, states={ - "health": DiscreteGrid([0, 1]), - "partner": DiscreteGrid([0, 1]), + "health": DiscreteGrid(HealthStatus), + "partner": DiscreteGrid(PartnerStatus), "wealth": LinspaceGrid( start=1, stop=100, diff --git a/tests/test_solution_on_toy_model.py b/tests/test_solution_on_toy_model.py index 3da3b008..5a02493a 100644 --- a/tests/test_solution_on_toy_model.py +++ b/tests/test_solution_on_toy_model.py @@ -1,7 +1,7 @@ """Test analytical solution and simulation with only discrete choices.""" from copy import deepcopy -from dataclasses import replace +from dataclasses import dataclass, replace import jax.numpy as jnp import numpy as np @@ -18,6 +18,24 @@ # ====================================================================================== # Model specification # ====================================================================================== +@dataclass +class ConsumptionChoice: + low: int = 0 + high: int = 1 + + +@dataclass +class WorkingState: + retired: int = 0 + working: int = 1 + + +@dataclass +class HealthState: + bad: int = 0 + good: int = 1 + + def utility(consumption, working, wealth, health): # noqa: ARG001 return jnp.log(1 + health * consumption) - 0.5 * working @@ -38,8 +56,8 @@ def consumption_constraint(consumption, wealth): }, n_periods=2, choices={ - "consumption": DiscreteGrid([0, 1]), - "working": DiscreteGrid([0, 1]), + "consumption": DiscreteGrid(ConsumptionChoice), + "working": DiscreteGrid(WorkingState), }, states={ "wealth": LinspaceGrid( @@ -58,7 +76,7 @@ def next_health(health): STOCHASTIC_MODEL = deepcopy(DETERMINISTIC_MODEL) STOCHASTIC_MODEL.functions["next_health"] = next_health -STOCHASTIC_MODEL.states["health"] = DiscreteGrid([0, 1]) +STOCHASTIC_MODEL.states["health"] = DiscreteGrid(HealthState) # ======================================================================================