From d126dbaa0a15f4085b66b608ff0d747a0e513e58 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 9 Sep 2024 15:49:43 +0200 Subject: [PATCH] Integrate comments from review --- .pre-commit-config.yaml | 7 +++++ src/lcm/create_params_template.py | 12 ++++---- src/lcm/discrete_problem.py | 8 +++--- src/lcm/entry_point.py | 39 ++++++++++++------------- src/lcm/interfaces.py | 48 +++++++++++++++---------------- src/lcm/process_model.py | 5 ++-- src/lcm/simulate.py | 6 ++-- src/lcm/typing.py | 2 +- tests/test_discrete_problem.py | 6 ++-- tests/test_model_functions.py | 2 +- tests/test_next_state.py | 2 +- tests/test_state_space.py | 2 +- 12 files changed, 73 insertions(+), 66 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f7d8110..9df51b4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -74,6 +74,13 @@ repos: - --wrap - '88' files: (README\.md) + - repo: https://github.com/kynan/nbstripout + rev: 0.7.1 + hooks: + - id: nbstripout + args: + - --drop-empty-cells + - --keep-output - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.11.2 hooks: diff --git a/src/lcm/create_params_template.py b/src/lcm/create_params_template.py index b0e1848..10915c7 100644 --- a/src/lcm/create_params_template.py +++ b/src/lcm/create_params_template.py @@ -14,7 +14,7 @@ def create_params_template( user_model: Model, variable_info: pd.DataFrame, grids: dict[str, Array], - default_params: dict[str, int | float] | None = None, + default_params: dict[str, float] | None = None, ) -> ParamsDict: """Create parameter template from a model specification. @@ -28,7 +28,7 @@ def create_params_template( np.nan} for beta-delta discounting. Returns: - ParamsDict: A nested dictionary of model parameters. + A nested dictionary of model parameters. """ if default_params is None: @@ -63,8 +63,8 @@ def _create_function_params(user_model: Model) -> dict[str, dict[str, float]]: user_model: The model as provided by the user. Returns: - dict: A dictionary for each model function, containing a parameters required in - the model functions, initialized with jnp.nan. + A dictionary for each model function, containing a parameters required in the + model functions, initialized with jnp.nan. """ # Collect all model variables, that includes choices, states, the period, and @@ -103,8 +103,8 @@ def _create_stochastic_transition_params( grids: A dictionary of grids consistent with user_model. Returns: - dict: A dictionary of parameters required for stochastic transitions, - initialized with jnp.nan matrices of the correct dimensions. + A dictionary of parameters required for stochastic transitions, initialized with + jnp.nan matrices of the correct dimensions. """ stochastic_variables = variable_info.query("is_stochastic").index.tolist() diff --git a/src/lcm/discrete_problem.py b/src/lcm/discrete_problem.py index 6cc664d..13bc65f 100644 --- a/src/lcm/discrete_problem.py +++ b/src/lcm/discrete_problem.py @@ -26,12 +26,12 @@ from jax import Array from jax.ops import segment_max -from lcm.typing import ParamsDict, SegmentInfo, Shock +from lcm.typing import ParamsDict, SegmentInfo, ShockType def get_solve_discrete_problem( *, - random_utility_shock_type: Shock, + random_utility_shock_type: ShockType, variable_info: pd.DataFrame, is_last_period: bool, choice_segments: SegmentInfo | None, @@ -61,9 +61,9 @@ def get_solve_discrete_problem( choice_axes = _determine_dense_discrete_choice_axes(variable_info) - if random_utility_shock_type == Shock.NONE: + if random_utility_shock_type == ShockType.NONE: func = _solve_discrete_problem_no_shocks - elif random_utility_shock_type == Shock.EXTREME_VALUE: + elif random_utility_shock_type == ShockType.EXTREME_VALUE: raise NotImplementedError("Extreme value shocks are not yet implemented.") else: raise ValueError(f"Invalid shock_type: {random_utility_shock_type}.") diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index de29992..b27a3ed 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -1,6 +1,7 @@ import functools +from collections.abc import Callable from functools import partial -from typing import Literal +from typing import Literal, cast import jax import jax.numpy as jnp @@ -19,6 +20,7 @@ from lcm.simulate import simulate from lcm.solve_brute import solve from lcm.state_space import create_state_choice_space +from lcm.typing import ParamsDict def get_lcm_function( @@ -28,37 +30,34 @@ def get_lcm_function( debug_mode: bool = True, jit: bool = True, interpolation_options: dict | None = None, -): +) -> tuple[Callable, ParamsDict]: """Entry point for users to get high level functions generated by lcm. + Return the function to solve and/or simulate a model along with a template for the + parameters. + Advanced users might want to use lower level functions instead, but can read the source code of this function to see how the lower level components are meant to be used. Notes: ----- - - Further targets could be "likelihood" or "simulate" - - We might need additional arguments such as solver_options that we want to take - separate from a model specification. - - create_params needs to work with a processed_model instead of a user model. - - currently all the preparations are hardcoded to generate the arguments needed - by solve_brute. In the long run, this needs to inspect the signature of the - solver, or generate only what is needed using dags. - - there is a hack to make the state_indexers empty in the last period which needs + - There is a hack to make the state_indexers empty in the last period which needs to be replaced by a better solution, when we want to allow for bequest motives. Args: - model (dict): User model specification. - targets (str or iterable): The requested function types. Currently only - "solve", "simulate" and "solve_and_simulate" are supported. - debug_mode (bool): Whether to log debug messages. - jit (bool): Whether to jit the returned function. - interpolation_options (dict): Dictionary of keyword arguments for interpolation + model: User model specification. + targets: The requested function types. Currently only "solve", "simulate" and + "solve_and_simulate" are supported. + debug_mode: Whether to log debug messages. + jit: Whether to jit the returned function. + interpolation_options: Dictionary of keyword arguments for interpolation via map_coordinates. If None, the default options are used. Returns: - callable: A function that takes params and returns the requested targets. - dict: A parameter dict where all parameter values are initialized to NaN. + - A function that takes params (and possibly other arguments, such as initial + states in the simulate case) and returns the requested targets. + - A parameter dictionary where all parameter values are initialized to NaN. """ # ================================================================================== @@ -154,7 +153,7 @@ def get_lcm_function( # create list of emax_calculators # ============================================================================== calculator = get_solve_discrete_problem( - random_utility_shock_type=_mod.shocks, + random_utility_shock_type=_mod.random_utility_shocks, variable_info=_mod.variable_info, is_last_period=is_last_period, choice_segments=choice_segments[period], @@ -195,7 +194,7 @@ def get_lcm_function( elif targets == "solve_and_simulate": _target = partial(simulate_model, solve_model=solve_model) - return _target, _mod.params + return cast(Callable, _target), _mod.params def create_compute_conditional_continuation_value( diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 1208a2e..9e71d81 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -5,7 +5,7 @@ from jax import Array from lcm.grids import ContinuousGrid, DiscreteGrid, Grid -from lcm.typing import ParamsDict, Shock +from lcm.typing import ParamsDict, ShockType @dataclass(frozen=True) @@ -73,28 +73,28 @@ class InternalModel: """Internal representation of a user model. Attributes: - grids (dict): Dictionary that maps names of model variables to grids of feasible - values for that variable. - gridspecs (dict): Dictionary that maps names of model variables to - specifications from which grids of feasible values can be built. - variable_info (pd.DataFrame): A table with information about all variables in - the model. The index contains the name of a model variable. The columns are - booleans that are True if the variable has the corresponding property. The - columns are: is_state, is_choice, is_continuous, is_discrete, is_sparse, - is_dense. - functions (dict): Dictionary that maps names of functions to functions. The - functions differ from the user functions in that they all except the - filter functions take ``params`` as keyword argument. If the original - function depended on model parameters, those are automatically extracted - from ``params`` and passed to the original function. Otherwise, the - ``params`` argument is simply ignored. - function_info (pd.DataFrame): A table with information about all functions in - the model. The index contains the name of a function. The columns are - booleans that are True if the function has the corresponding property. The - columns are: is_filter, is_constraint, is_next. - params (dict): Dict of model parameters. - n_periods (int): Number of periods. - shocks (Shock): Type of shocks. + grids: Dictionary that maps names of model variables to grids of feasible values + for that variable. + gridspecs: Dictionary that maps names of model variables to specifications from + which grids of feasible values can be built. + variable_info: A table with information about all variables in the model. The + index contains the name of a model variable. The columns are booleans that + are True if the variable has the corresponding property. The columns are: + is_state, is_choice, is_continuous, is_discrete, is_sparse, columns are: + is_state, is_choice, is_continuous, is_discrete, is_sparse, is_dense. + functions: Dictionary that maps names of functions to functions. The functions + differ from the user functions in that they all except the filter functions + take ``params`` as keyword argument. If the original function depended on + model parameters, those are automatically extracted from ``params`` and + passed to the original function. Otherwise, the ``params`` argument is + simply ignored. + function_info: A table with information about all functions in the model. The + index contains the name of a function. The columns are booleans that are + True if the function has the corresponding property. The columns are: + is_filter, is_constraint, is_next. + params: Dict of model parameters. + n_periods: Number of periods. + random_utility_shocks: Type of random utility shocks. """ @@ -106,4 +106,4 @@ class InternalModel: params: ParamsDict n_periods: int # Not properly processed yet - shocks: Shock + random_utility_shocks: ShockType diff --git a/src/lcm/process_model.py b/src/lcm/process_model.py index de55cd9..14e6c34 100644 --- a/src/lcm/process_model.py +++ b/src/lcm/process_model.py @@ -16,7 +16,7 @@ ) from lcm.interfaces import InternalModel from lcm.model import Model -from lcm.typing import ParamsDict, Shock +from lcm.typing import ParamsDict, ShockType def process_model(user_model: Model) -> InternalModel: @@ -66,7 +66,8 @@ def process_model(user_model: Model) -> InternalModel: functions=functions, function_info=function_info, params=params, - shocks=Shock.NONE, # currently no additive utility shocks are supported + # currently no additive utility shocks are supported + random_utility_shocks=ShockType.NONE, n_periods=user_model.n_periods, ) diff --git a/src/lcm/simulate.py b/src/lcm/simulate.py index e407e26..f2cd773 100644 --- a/src/lcm/simulate.py +++ b/src/lcm/simulate.py @@ -14,12 +14,12 @@ def simulate( params, + initial_states, state_indexers, continuous_choice_grids, compute_ccv_policy_functions, model, next_state, - initial_states, logger, solve_model=None, vf_arr_list=None, @@ -30,6 +30,8 @@ def simulate( Args: params (dict): Dict of model parameters. + initial_states (list): List of initial states to start from. Typically from the + observed dataset. state_indexers (list): List of dicts of length n_periods. Each dict contains one or several state indexers. continuous_choice_grids (list): List of dicts of length n_periods. Each dict @@ -41,8 +43,6 @@ def simulate( state and choice variables. For stochastic variables, it returns a random draw from the distribution of the next state. model (Model): Model instance. - initial_states (list): List of initial states to start from. Typically from the - observed dataset. logger (logging.Logger): Logger that logs to stdout. solve_model (callable): Function that solves the model. Is only required if vf_arr_list is not provided. diff --git a/src/lcm/typing.py b/src/lcm/typing.py index 1e1776d..afef46d 100644 --- a/src/lcm/typing.py +++ b/src/lcm/typing.py @@ -10,7 +10,7 @@ ParamsDict = dict[str, Any] -class Shock(Enum): +class ShockType(Enum): """Type of shocks.""" EXTREME_VALUE = "extreme_value" diff --git a/tests/test_discrete_problem.py b/tests/test_discrete_problem.py index cfaaad4..a1d0767 100644 --- a/tests/test_discrete_problem.py +++ b/tests/test_discrete_problem.py @@ -15,7 +15,7 @@ _solve_discrete_problem_no_shocks, get_solve_discrete_problem, ) -from lcm.typing import Shock +from lcm.typing import ShockType @pytest.fixture @@ -50,7 +50,7 @@ def test_aggregation_without_shocks(cc_values, segment_info, collapse, n_extra_a ) solve_discrete_problem = get_solve_discrete_problem( - random_utility_shock_type=Shock.NONE, + random_utility_shock_type=ShockType.NONE, variable_info=var_info, is_last_period=False, choice_segments=segment_info, @@ -155,7 +155,7 @@ def test_get_solve_discrete_problem_illustrative(): ) # leads to choice_axes = [1] solve_discrete_problem = get_solve_discrete_problem( - random_utility_shock_type=Shock.NONE, + random_utility_shock_type=ShockType.NONE, variable_info=variable_info, is_last_period=False, choice_segments=None, diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index 85a0228..63e1dc0 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -34,7 +34,7 @@ def h(): functions={"f": f, "g": g, "h": h}, function_info=function_info, params=None, - shocks=None, + random_utility_shocks=None, n_periods=None, ) combined_constraint = get_combined_constraint(model) diff --git a/tests/test_next_state.py b/tests/test_next_state.py index 236e11e..e723bb6 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -68,7 +68,7 @@ def f_weight_b(state): # noqa: ARG001 gridspecs=None, variable_info=None, params=None, - shocks=None, + random_utility_shocks=None, n_periods=1, ) diff --git a/tests/test_state_space.py b/tests/test_state_space.py index 1867107..de7677d 100644 --- a/tests/test_state_space.py +++ b/tests/test_state_space.py @@ -68,7 +68,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): variable_info=None, functions=functions, function_info=function_info, - shocks=None, + random_utility_shocks=None, n_periods=100, params={}, )