From 420eedd06a4e9aae7d320ba3695da78df1f88c29 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 23 Sep 2024 13:31:31 +0200 Subject: [PATCH] Remove discrete grid converter --- src/lcm/entry_point.py | 5 +- src/lcm/input_processing/__init__.py | 3 +- .../discrete_grid_conversion.py | 280 ------------------ src/lcm/input_processing/process_model.py | 20 +- src/lcm/interfaces.py | 4 - src/lcm/simulate.py | 25 +- src/lcm/solve_brute.py | 9 +- .../test_discrete_state_conversion.py | 155 ---------- tests/test_model_functions.py | 1 - tests/test_next_state.py | 1 - tests/test_simulate.py | 15 +- tests/test_solve_brute.py | 2 - tests/test_state_space.py | 1 - 13 files changed, 22 insertions(+), 499 deletions(-) delete mode 100644 src/lcm/input_processing/discrete_grid_conversion.py delete mode 100644 tests/input_processing/test_discrete_state_conversion.py diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index f41c247..c588999 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -170,7 +170,6 @@ def get_lcm_function( continuous_choice_grids=continuous_choice_grids, compute_ccv_functions=compute_ccv_functions, emax_calculators=emax_calculators, - discrete_grid_converter=_mod.discrete_grid_converter, logger=logger, ) @@ -184,7 +183,6 @@ def get_lcm_function( compute_ccv_policy_functions=compute_ccv_policy_functions, model=_mod, next_state=jax.jit(_next_state_simulate), - discrete_grid_converter=_mod.discrete_grid_converter, logger=logger, ) @@ -195,8 +193,7 @@ def get_lcm_function( elif targets == "solve_and_simulate": _target = partial(simulate_model, solve_model=solve_model) - user_params = _mod.discrete_grid_converter.internal_params_to_params(_mod.params) - return cast(Callable, _target), user_params + return cast(Callable, _target), _mod.params def create_compute_conditional_continuation_value( diff --git a/src/lcm/input_processing/__init__.py b/src/lcm/input_processing/__init__.py index 546866d..849cbdc 100644 --- a/src/lcm/input_processing/__init__.py +++ b/src/lcm/input_processing/__init__.py @@ -1,4 +1,3 @@ -from .discrete_grid_conversion import DiscreteGridConverter from .process_model import process_model -__all__ = ["process_model", "DiscreteGridConverter"] +__all__ = ["process_model"] diff --git a/src/lcm/input_processing/discrete_grid_conversion.py b/src/lcm/input_processing/discrete_grid_conversion.py deleted file mode 100644 index a1084e2..0000000 --- a/src/lcm/input_processing/discrete_grid_conversion.py +++ /dev/null @@ -1,280 +0,0 @@ -import functools -from collections.abc import Callable -from dataclasses import dataclass, field, make_dataclass -from typing import TypeVar, cast - -import jax.numpy as jnp -from dags.signature import with_signature -from jax import Array - -from lcm.functools import all_as_kwargs -from lcm.grids import DiscreteGrid -from lcm.input_processing.util import ( - get_gridspecs, -) -from lcm.typing import ParamsDict -from lcm.user_model import Model - - -@dataclass(frozen=True, kw_only=True) -class DiscreteGridConverter: - """Converts between representations of discrete variables and associated parameters. - - While LCM supports general discrete grids, internally, these are converted to - array indices. This class provides functionality for converting between the internal - representation and the external representation of states, choices, and associated - parameters. - - Attributes: - index_to_code: A dictionary of functions mapping from the internal index to the - code for each converted state. Keys correspond to the names of converted - discrete variables. - code_to_index: A dictionary of functions mapping from the code to the internal - index for each converted state. Keys correspond to the names of converted - discrete variables. - - """ - - index_to_code: dict[str, Callable[[Array], Array]] = field(default_factory=dict) - code_to_index: dict[str, Callable[[Array], Array]] = field(default_factory=dict) - - def __post_init__(self) -> None: - if set(self.index_to_code.keys()) != set(self.code_to_index.keys()): - raise ValueError( - "The keys of index_to_code and code_to_index must be the same." - ) - - def internal_params_to_params(self, internal: ParamsDict) -> ParamsDict: - """Convert parameters from internal to external representation. - - If a state has been converted, the name of its corresponding next function must - be changed from `next___{var}_index__` to `next_{var}`. - - """ - params = internal.copy() - for var in self.index_to_code: - old_name = f"next___{var}_index__" - if old_name in params: - params[f"next_{var}"] = params.pop(old_name) - return params - - def params_to_internal_params(self, params: ParamsDict) -> ParamsDict: - """Convert parameters from external to internal representation. - - If a state has been converted, the name of its corresponding next function must - be changed from `next_{var}` to `next___{var}_index__`. - - """ - internal = params.copy() - for var in self.index_to_code: - old_name = f"next_{var}" - if old_name in internal: - internal[f"next___{var}_index__"] = internal.pop(old_name) - return internal - - def internal_vars_to_vars(self, internal: dict[str, Array]) -> dict[str, Array]: - """Convert discrete variables from internal to external representation. - - If a variable has been converted, the name of its corresponding index function - must be changed from `___{var}_index__` to `{var}`, and the values of the - variable must be converted from indices to codes. - - """ - variables = internal.copy() - for var, index_to_code in self.index_to_code.items(): - old_name = f"__{var}_index__" - if old_name in internal: - variables[var] = index_to_code(variables.pop(old_name)) - return variables - - def vars_to_internal_vars(self, variables: dict[str, Array]) -> dict[str, Array]: - """Convert discrete variables from external to internal representation. - - If a variable has been converted, the name of its corresponding index function - must be changed from `{var}` to `___{var}_index__`, and the values of the - variable must be converted from codes to indices. - - """ - internal = variables.copy() - for var, code_to_index in self.code_to_index.items(): - if var in variables: - internal[f"__{var}_index__"] = code_to_index(internal.pop(var)) - return internal - - -def convert_arbitrary_codes_to_array_indices( - model: Model, -) -> tuple[Model, DiscreteGridConverter]: - """Update the user model to ensure that discrete variables have index codes. - - For each discrete variable with non-index codes, we: - - 1. Remove the variable from the states or choices dictionary - 2. Replace it with a new state or choice with array index codes - 3. Add updated next functions (if the variable was a state variable) - 4. Add a function that maps the array index codes to the original codes - - Args: - model: The model as provided by the user. - - Returns: - - The model with all discrete variables having index codes. - - A converter that can be used to convert between the internal and external - representation of the model. - - """ - gridspecs = get_gridspecs(model) - - non_index_discrete_vars = _get_discrete_vars_with_non_index_codes(model) - - # fast path - if not non_index_discrete_vars: - return model, DiscreteGridConverter() - - functions = model.functions.copy() - states = model.states.copy() - choices = model.choices.copy() - - # Update grids - # ---------------------------------------------------------------------------------- - for var in non_index_discrete_vars: - grid: DiscreteGrid = gridspecs[var] # type: ignore[assignment] - index_category_class = make_dataclass( - grid.__str__(), - [(f"__{name}_index__", int, i) for i, name in enumerate(grid.categories)], - ) - index_grid = DiscreteGrid(index_category_class) - - if var in model.states: - states.pop(var) - states[f"__{var}_index__"] = index_grid - else: - choices.pop(var) - choices[f"__{var}_index__"] = index_grid - - # Update next functions - # ---------------------------------------------------------------------------------- - non_index_states = [s for s in model.states if s in non_index_discrete_vars] - - for var in non_index_states: - functions[f"next___{var}_index__"] = _get_next_index_func( - functions.pop(f"next_{var}"), - codes_array=gridspecs[var].to_jax(), - name=var, - ) - - # Add index to code functions - # ---------------------------------------------------------------------------------- - index_to_code_funcs = { - var: _get_index_to_code_func(gridspecs[var].to_jax(), name=var) - for var in non_index_discrete_vars - } - functions = functions | index_to_code_funcs - - # Create code to index functions for converter - # ---------------------------------------------------------------------------------- - code_to_index_funcs = { - var: _get_code_to_index_func(gridspecs[var].to_jax(), name=var) - for var in non_index_discrete_vars - } - - discrete_grid_converter = DiscreteGridConverter( - index_to_code=index_to_code_funcs, - code_to_index=code_to_index_funcs, - ) - - new_model = model.replace( - states=states, - choices=choices, - functions=functions, - ) - return new_model, discrete_grid_converter - - -def _get_discrete_vars_with_non_index_codes(model: Model) -> list[str]: - """Get discrete variables with non-index codes. - - Collect all discrete variables with codes that do not correspond to indices. - - """ - gridspecs = get_gridspecs(model) - discrete_vars = [] - for name, spec in gridspecs.items(): - if isinstance(spec, DiscreteGrid) and list(spec.codes) != list( - range(len(spec.codes)) - ): - discrete_vars.append(name) - return discrete_vars - - -F = TypeVar("F", bound=Callable[[tuple[Array, ...]], Array]) - - -def _get_next_index_func(next_func: F, codes_array: Array, name: str) -> F: - """Get next function for index state variable. - - Args: - next_func: The next function for the state variable. - codes_array: An array of codes. - name: The name of the state variable. - - Returns: - A next function corresponding to the index version of the state variable. - - """ - code_to_index = _get_code_to_index_func(codes_array, name) - - @functools.wraps(next_func) - def next_index_func(*args, **kwargs) -> Array: - next_state = next_func(*args, **kwargs) - return code_to_index(next_state) - - return cast(F, next_index_func) - - -def _get_index_to_code_func(codes_array: Array, name: str) -> Callable[[Array], Array]: - """Get function mapping from index to code. - - Args: - codes_array: An array of codes. - name: The name of resulting function argument. - - Returns: - A function mapping an array with indices corresponding to codes_array to the - corresponding codes. - - """ - arg_name = f"__{name}_index__" - - @with_signature(args=[arg_name]) - def func(*args, **kwargs): - kwargs = all_as_kwargs(args, kwargs, arg_names=[arg_name]) - index = kwargs[arg_name] - return codes_array[index] - - return func - - -def _get_code_to_index_func(codes_array: Array, name: str) -> Callable[[Array], Array]: - """Get function mapping from code to index. - - Args: - codes_array: An array of codes. - name: The name of resulting function argument. - - Returns: - A function mapping an array with values in codes_array to their corresponding - indices. - - """ - sorted_indices = jnp.argsort(codes_array) - sorted_codes = codes_array[sorted_indices] - - @with_signature(args=[name]) - def code_to_index(*args, **kwargs): - kwargs = all_as_kwargs(args, kwargs, arg_names=[name]) - data = kwargs[name] - indices_in_sorted = jnp.searchsorted(sorted_codes, data) - return sorted_indices[indices_in_sorted] - - return code_to_index diff --git a/src/lcm/input_processing/process_model.py b/src/lcm/input_processing/process_model.py index f67420d..179c626 100644 --- a/src/lcm/input_processing/process_model.py +++ b/src/lcm/input_processing/process_model.py @@ -9,9 +9,6 @@ from lcm.functools import all_as_args, all_as_kwargs from lcm.input_processing.create_params_template import create_params_template -from lcm.input_processing.discrete_grid_conversion import ( - convert_arbitrary_codes_to_array_indices, -) from lcm.input_processing.util import ( get_function_info, get_grids, @@ -39,21 +36,18 @@ def process_model(model: Model) -> InternalModel: The processed model. """ - tmp_model, converter = convert_arbitrary_codes_to_array_indices(model) - - params = create_params_template(tmp_model) + params = create_params_template(model) return InternalModel( - grids=get_grids(tmp_model), - gridspecs=get_gridspecs(tmp_model), - variable_info=get_variable_info(tmp_model), - functions=_get_internal_functions(tmp_model, params=params), - function_info=get_function_info(tmp_model), + grids=get_grids(model), + gridspecs=get_gridspecs(model), + variable_info=get_variable_info(model), + functions=_get_internal_functions(model, params=params), + function_info=get_function_info(model), params=params, - discrete_grid_converter=converter, # currently no additive utility shocks are supported random_utility_shocks=ShockType.NONE, - n_periods=tmp_model.n_periods, + n_periods=model.n_periods, ) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 69ecaf5..9e71d81 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -5,7 +5,6 @@ from jax import Array from lcm.grids import ContinuousGrid, DiscreteGrid, Grid -from lcm.input_processing import DiscreteGridConverter from lcm.typing import ParamsDict, ShockType @@ -94,8 +93,6 @@ class InternalModel: True if the function has the corresponding property. The columns are: is_filter, is_constraint, is_next. params: Dict of model parameters. - discrete_grid_converter: Helps with converting between internal and external - representations of model variables and associated parameters. n_periods: Number of periods. random_utility_shocks: Type of random utility shocks. @@ -107,7 +104,6 @@ class InternalModel: functions: dict[str, Callable] function_info: pd.DataFrame params: ParamsDict - discrete_grid_converter: DiscreteGridConverter n_periods: int # Not properly processed yet random_utility_shocks: ShockType diff --git a/src/lcm/simulate.py b/src/lcm/simulate.py index aab880b..561cfce 100644 --- a/src/lcm/simulate.py +++ b/src/lcm/simulate.py @@ -9,7 +9,6 @@ from lcm.argmax import argmax, segment_argmax from lcm.dispatchers import spacemap, vmap_1d -from lcm.input_processing import DiscreteGridConverter from lcm.interfaces import InternalModel, Space @@ -21,7 +20,6 @@ def simulate( compute_ccv_policy_functions, model: InternalModel, next_state, - discrete_grid_converter: DiscreteGridConverter, logger, solve_model=None, vf_arr_list=None, @@ -45,8 +43,6 @@ def simulate( state and choice variables. For stochastic variables, it returns a random draw from the distribution of the next state. model (Model): Model instance. - discrete_grid_converter (DiscreteGridConverter): Converter for discrete - variables and parameters between external and internal representation. logger (logging.Logger): Logger that logs to stdout. solve_model (callable): Function that solves the model. Is only required if vf_arr_list is not provided. @@ -71,8 +67,6 @@ def simulate( # will do it. vf_arr_list = solve_model(params) - internal_params = discrete_grid_converter.params_to_internal_params(params) - logger.info("Starting simulation") # Update the vf_arr_list @@ -97,7 +91,7 @@ def simulate( sparse_choice_variables = model.variable_info.query("is_choice & is_sparse").index # The following variables are updated during the forward simulation - states = discrete_grid_converter.vars_to_internal_vars(initial_states) + states = initial_states key = jax.random.PRNGKey(seed=seed) # Forward simulation @@ -140,7 +134,7 @@ def simulate( continuous_choice_grids=continuous_choice_grids[period], vf_arr=vf_arr_list[period], state_indexers=state_indexers[period], - params=internal_params, + params=params, ) # Get optimal discrete choice given the optimal conditional continuous choices @@ -203,7 +197,7 @@ def simulate( **states, **choices, _period=jnp.repeat(period, n_initial_states), - params=internal_params, + params=params, keys=sim_keys, ) @@ -213,14 +207,14 @@ def simulate( logger.info("Period: %s", period) - processed = _process_simulated_data(_simulation_results, discrete_grid_converter) + processed = _process_simulated_data(_simulation_results) if additional_targets is not None: calculated_targets = _compute_targets( processed, targets=additional_targets, model_functions=model.functions, - params=internal_params, + params=params, ) processed = {**processed, **calculated_targets} @@ -339,23 +333,20 @@ def _compute_targets(processed_results, targets, model_functions, params): return target_func(params=params, **kwargs) -def _process_simulated_data(results, discrete_grid_converter): +def _process_simulated_data(results): """Process and flatten the simulation results. This function produces a dict of arrays for each var with dimension (n_periods * n_initial_states,). The arrays are flattened, so that the resulting dictionary has a one-dimensional array for each variable. The length of this array is the number of periods times the number of initial states. The order of array elements is given by - an outer level of periods and an inner level of initial states ids. Discrete - variables with an internal representation are converted to their external - representation. + an outer level of periods and an inner level of initial states ids. Args: results (list): List of dicts with simulation results. Each dict contains the value, choices, and states for one period. Choices and states are stored in a nested dictionary. - discrete_grid_converter (DiscreteGridConverter): Converter for discrete grids. Returns: dict: Dict with processed simulation results. The keys are the variable names @@ -375,7 +366,7 @@ def _process_simulated_data(results, discrete_grid_converter): out = {key: jnp.concatenate(values) for key, values in dict_of_lists.items()} out["_period"] = jnp.repeat(jnp.arange(n_periods), n_initial_states) - return discrete_grid_converter.internal_vars_to_vars(out) + return out # ====================================================================================== diff --git a/src/lcm/solve_brute.py b/src/lcm/solve_brute.py index 978627d..a7f1ab8 100644 --- a/src/lcm/solve_brute.py +++ b/src/lcm/solve_brute.py @@ -10,7 +10,6 @@ def solve( continuous_choice_grids, compute_ccv_functions, emax_calculators, - discrete_grid_converter, logger, ): """Solve a model by brute force. @@ -43,16 +42,12 @@ def solve( emax_calculators (list): List of functions that take continuation values for combinations of states and discrete choices and calculate the expected maximum over all discrete choices of a given state. - discrete_grid_converter (DiscreteGridConverter): Converter for discrete - variables and parameters between external and internal representation. logger (logging.Logger): Logger that logs to stdout. Returns: list: List with one value function array per period. """ - internal_params = discrete_grid_converter.params_to_internal_params(params) - # extract information n_periods = len(state_choice_spaces) reversed_solution = [] @@ -69,12 +64,12 @@ def solve( continuous_choice_grids=continuous_choice_grids[period], vf_arr=vf_arr, state_indexers=state_indexers[period], - params=internal_params, + params=params, ) # solve discrete problem by calculating expected maximum over discrete choices calculate_emax = emax_calculators[period] - vf_arr = calculate_emax(conditional_continuation_values, params=internal_params) + vf_arr = calculate_emax(conditional_continuation_values, params=params) reversed_solution.append(vf_arr) logger.info("Period: %s", period) diff --git a/tests/input_processing/test_discrete_state_conversion.py b/tests/input_processing/test_discrete_state_conversion.py deleted file mode 100644 index eecb273..0000000 --- a/tests/input_processing/test_discrete_state_conversion.py +++ /dev/null @@ -1,155 +0,0 @@ -from dataclasses import dataclass -from typing import Any - -import jax.numpy as jnp -import pytest -from numpy.testing import assert_array_equal - -from lcm import DiscreteGrid -from lcm.input_processing.discrete_grid_conversion import ( - DiscreteGridConverter, - _get_code_to_index_func, - _get_discrete_vars_with_non_index_codes, - _get_index_to_code_func, - _get_next_index_func, - convert_arbitrary_codes_to_array_indices, -) - - -@dataclass -class ModelMock: - """A model mock for testing the process_model function. - - This dataclass has the same attributes as the Model dataclass, but does not perform - any checks, which helps us to test the process_model function in isolation. - - """ - - n_periods: int - functions: dict[str, Any] - choices: dict[str, Any] - states: dict[str, Any] - - -@pytest.fixture -def model(binary_category_class): - def next_c(a, b): - return a + b - - return ModelMock( - n_periods=2, - functions={ - "next_c": next_c, - }, - choices={ - "a": DiscreteGrid(binary_category_class), - }, - states={ - "c": DiscreteGrid(binary_category_class), - }, - ) - - -@pytest.fixture -def discrete_state_converter_kwargs(): - return { - "index_to_code": {"c": _get_index_to_code_func(jnp.array([1, 0]), name="c")}, - "code_to_index": {"c": _get_code_to_index_func(jnp.array([1, 0]), name="c")}, - } - - -def test_discrete_state_converter_internal_to_params(discrete_state_converter_kwargs): - expected = { - "next_c": 1, - } - internal_params = { - "next___c_index__": 1, - } - converter = DiscreteGridConverter(**discrete_state_converter_kwargs) - assert converter.internal_params_to_params(internal_params) == expected - - -def test_discrete_state_converter_params_to_internal(discrete_state_converter_kwargs): - expected = { - "next___c_index__": 1, - } - params = { - "next_c": 1, - } - converter = DiscreteGridConverter(**discrete_state_converter_kwargs) - assert converter.params_to_internal_params(params) == expected - - -def test_discrete_state_converter_internal_to_discrete_vars( - discrete_state_converter_kwargs, -): - expected = jnp.array([1, 0]) - internal_states = { - "__c_index__": jnp.array([0, 1]), - } - converter = DiscreteGridConverter(**discrete_state_converter_kwargs) - assert_array_equal(converter.internal_vars_to_vars(internal_states)["c"], expected) - - -def test_discrete_state_converter_discrete_vars_to_internal( - discrete_state_converter_kwargs, -): - expected = jnp.array([0, 1]) - states = { - "c": jnp.array([1, 0]), - } - converter = DiscreteGridConverter(**discrete_state_converter_kwargs) - assert_array_equal(converter.vars_to_internal_vars(states)["__c_index__"], expected) - - -def test_discrete_state_converter_raises_error_if_keys_dont_match(): - index_to_code = {"a": 0} - code_to_index = {"b": 0} - with pytest.raises( - ValueError, - match="The keys of index_to_code and code_to_index must be the same.", - ): - DiscreteGridConverter(index_to_code=index_to_code, code_to_index=code_to_index) - - -def test_get_index_to_label_func(): - codes_array = jnp.array([1, 0]) - got = _get_index_to_code_func(codes_array, name="foo") - assert got(__foo_index__=0) == 1 - assert got(1) == 0 - - -def test_get_code_to_index_func(): - codes_array = jnp.array([1, 0]) - got = _get_code_to_index_func(codes_array, name="foo") - assert_array_equal(got(foo=codes_array), jnp.arange(2)) - - -def test_get_discrete_vars_with_non_index_codes(model): - got = _get_discrete_vars_with_non_index_codes(model) - assert got == ["c"] - - -def test_convert_discrete_codes_to_indices(model): - # add replace method to model mock - model.replace = lambda **kwargs: ModelMock(**kwargs, n_periods=model.n_periods) - - got, _ = convert_arbitrary_codes_to_array_indices(model) - - assert "c" not in got.states - assert "__c_index__" in got.states - assert "c" in got.functions - assert got.states["__c_index__"].categories == ("__cat0_index__", "__cat1_index__") - assert got.states["__c_index__"].codes == (0, 1) - assert got.functions["c"](0) == 1 - assert got.functions["c"](1) == 0 - - -def test_get_next_index_func(): - got_func = _get_next_index_func( - next_func=lambda wealth, working: jnp.clip(wealth + working, 1, 3), - codes_array=jnp.array([1, 2, 3]), - name="wealth", - ) - got = got_func(wealth=jnp.array([1, 2]), working=jnp.array([0, 1])) - assert_array_equal(got, jnp.array([0, 2])) diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index ad1b1ce..bd9a933 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -34,7 +34,6 @@ def h(): functions={"f": f, "g": g, "h": h}, function_info=function_info, params=None, - discrete_grid_converter=None, random_utility_shocks=None, n_periods=None, ) diff --git a/tests/test_next_state.py b/tests/test_next_state.py index 416c639..7bdcbb2 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -68,7 +68,6 @@ def f_weight_b(state): # noqa: ARG001 gridspecs=None, variable_info=None, params=None, - discrete_grid_converter=None, random_utility_shocks=None, n_periods=1, ) diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 7f1de65..f9fd516 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -9,7 +9,7 @@ create_compute_conditional_continuation_policy, get_lcm_function, ) -from lcm.input_processing import DiscreteGridConverter, process_model +from lcm.input_processing import process_model from lcm.logging import get_logger from lcm.model_functions import get_utility_and_feasibility_function from lcm.next_state import _get_next_state_function_simulation @@ -75,7 +75,6 @@ def simulate_inputs(): "compute_ccv_policy_functions": compute_ccv_policy_functions, "model": model, "next_state": _get_next_state_function_simulation(model), - "discrete_grid_converter": DiscreteGridConverter(), } @@ -169,13 +168,7 @@ def test_simulate_using_get_lcm_function( assert (res.loc[period]["value"].diff()[1:] >= 0).all() -# ====================================================================================== -# Test simulation works correctly with discrete grid conversion -# ====================================================================================== - - -def test_simulate_with_discrete_grid_conversion(): - """Test that the simulation works correctly with discrete grid conversion.""" +def test_simulate_with_only_discrete_choices(): model = get_model_config("iskhakov_et_al_2017_discrete", n_periods=2) params = get_params(wage=1.5, beta=1, interest_rate=0) @@ -377,9 +370,7 @@ def test_process_simulated_data(): "b": jnp.array([-1, -2, -3, -4]), } - got = _process_simulated_data( - simulated, discrete_grid_converter=DiscreteGridConverter() - ) + got = _process_simulated_data(simulated) assert tree_equal(expected, got) diff --git a/tests/test_solve_brute.py b/tests/test_solve_brute.py index 1cff2ec..7fc11f2 100644 --- a/tests/test_solve_brute.py +++ b/tests/test_solve_brute.py @@ -4,7 +4,6 @@ from numpy.testing import assert_array_almost_equal as aaae from lcm.entry_point import create_compute_conditional_continuation_value -from lcm.input_processing import DiscreteGridConverter from lcm.interfaces import Space from lcm.logging import get_logger from lcm.solve_brute import solve, solve_continuous_problem @@ -113,7 +112,6 @@ def calculate_emax(values, params): # noqa: ARG001 continuous_choice_grids=continuous_choice_grids, compute_ccv_functions=utility_and_feasibility_functions, emax_calculators=emax_calculators, - discrete_grid_converter=DiscreteGridConverter(), logger=get_logger(debug_mode=False), ) diff --git a/tests/test_state_space.py b/tests/test_state_space.py index 8449e76..cd0fb2b 100644 --- a/tests/test_state_space.py +++ b/tests/test_state_space.py @@ -69,7 +69,6 @@ def absorbing_retirement_filter(retirement, lagged_retirement): functions=functions, function_info=function_info, params=None, - discrete_grid_converter=None, random_utility_shocks=None, n_periods=100, )