Skip to content

Commit

Permalink
Integrate comments from review
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 9, 2024
1 parent 216a583 commit d126dba
Show file tree
Hide file tree
Showing 12 changed files with 73 additions and 66 deletions.
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ repos:
- --wrap
- '88'
files: (README\.md)
- repo: https://github.com/kynan/nbstripout
rev: 0.7.1
hooks:
- id: nbstripout
args:
- --drop-empty-cells
- --keep-output
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
hooks:
Expand Down
12 changes: 6 additions & 6 deletions src/lcm/create_params_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def create_params_template(
user_model: Model,
variable_info: pd.DataFrame,
grids: dict[str, Array],
default_params: dict[str, int | float] | None = None,
default_params: dict[str, float] | None = None,
) -> ParamsDict:
"""Create parameter template from a model specification.
Expand All @@ -28,7 +28,7 @@ def create_params_template(
np.nan} for beta-delta discounting.
Returns:
ParamsDict: A nested dictionary of model parameters.
A nested dictionary of model parameters.
"""
if default_params is None:
Expand Down Expand Up @@ -63,8 +63,8 @@ def _create_function_params(user_model: Model) -> dict[str, dict[str, float]]:
user_model: The model as provided by the user.
Returns:
dict: A dictionary for each model function, containing a parameters required in
the model functions, initialized with jnp.nan.
A dictionary for each model function, containing a parameters required in the
model functions, initialized with jnp.nan.
"""
# Collect all model variables, that includes choices, states, the period, and
Expand Down Expand Up @@ -103,8 +103,8 @@ def _create_stochastic_transition_params(
grids: A dictionary of grids consistent with user_model.
Returns:
dict: A dictionary of parameters required for stochastic transitions,
initialized with jnp.nan matrices of the correct dimensions.
A dictionary of parameters required for stochastic transitions, initialized with
jnp.nan matrices of the correct dimensions.
"""
stochastic_variables = variable_info.query("is_stochastic").index.tolist()
Expand Down
8 changes: 4 additions & 4 deletions src/lcm/discrete_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
from jax import Array
from jax.ops import segment_max

from lcm.typing import ParamsDict, SegmentInfo, Shock
from lcm.typing import ParamsDict, SegmentInfo, ShockType


def get_solve_discrete_problem(
*,
random_utility_shock_type: Shock,
random_utility_shock_type: ShockType,
variable_info: pd.DataFrame,
is_last_period: bool,
choice_segments: SegmentInfo | None,
Expand Down Expand Up @@ -61,9 +61,9 @@ def get_solve_discrete_problem(

choice_axes = _determine_dense_discrete_choice_axes(variable_info)

if random_utility_shock_type == Shock.NONE:
if random_utility_shock_type == ShockType.NONE:
func = _solve_discrete_problem_no_shocks
elif random_utility_shock_type == Shock.EXTREME_VALUE:
elif random_utility_shock_type == ShockType.EXTREME_VALUE:
raise NotImplementedError("Extreme value shocks are not yet implemented.")
else:
raise ValueError(f"Invalid shock_type: {random_utility_shock_type}.")
Expand Down
39 changes: 19 additions & 20 deletions src/lcm/entry_point.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
from collections.abc import Callable
from functools import partial
from typing import Literal
from typing import Literal, cast

import jax
import jax.numpy as jnp
Expand All @@ -19,6 +20,7 @@
from lcm.simulate import simulate
from lcm.solve_brute import solve
from lcm.state_space import create_state_choice_space
from lcm.typing import ParamsDict


def get_lcm_function(
Expand All @@ -28,37 +30,34 @@ def get_lcm_function(
debug_mode: bool = True,
jit: bool = True,
interpolation_options: dict | None = None,
):
) -> tuple[Callable, ParamsDict]:
"""Entry point for users to get high level functions generated by lcm.
Return the function to solve and/or simulate a model along with a template for the
parameters.
Advanced users might want to use lower level functions instead, but can read the
source code of this function to see how the lower level components are meant to be
used.
Notes:
-----
- Further targets could be "likelihood" or "simulate"
- We might need additional arguments such as solver_options that we want to take
separate from a model specification.
- create_params needs to work with a processed_model instead of a user model.
- currently all the preparations are hardcoded to generate the arguments needed
by solve_brute. In the long run, this needs to inspect the signature of the
solver, or generate only what is needed using dags.
- there is a hack to make the state_indexers empty in the last period which needs
- There is a hack to make the state_indexers empty in the last period which needs
to be replaced by a better solution, when we want to allow for bequest motives.
Args:
model (dict): User model specification.
targets (str or iterable): The requested function types. Currently only
"solve", "simulate" and "solve_and_simulate" are supported.
debug_mode (bool): Whether to log debug messages.
jit (bool): Whether to jit the returned function.
interpolation_options (dict): Dictionary of keyword arguments for interpolation
model: User model specification.
targets: The requested function types. Currently only "solve", "simulate" and
"solve_and_simulate" are supported.
debug_mode: Whether to log debug messages.
jit: Whether to jit the returned function.
interpolation_options: Dictionary of keyword arguments for interpolation
via map_coordinates. If None, the default options are used.
Returns:
callable: A function that takes params and returns the requested targets.
dict: A parameter dict where all parameter values are initialized to NaN.
- A function that takes params (and possibly other arguments, such as initial
states in the simulate case) and returns the requested targets.
- A parameter dictionary where all parameter values are initialized to NaN.
"""
# ==================================================================================
Expand Down Expand Up @@ -154,7 +153,7 @@ def get_lcm_function(
# create list of emax_calculators
# ==============================================================================
calculator = get_solve_discrete_problem(
random_utility_shock_type=_mod.shocks,
random_utility_shock_type=_mod.random_utility_shocks,
variable_info=_mod.variable_info,
is_last_period=is_last_period,
choice_segments=choice_segments[period],
Expand Down Expand Up @@ -195,7 +194,7 @@ def get_lcm_function(
elif targets == "solve_and_simulate":
_target = partial(simulate_model, solve_model=solve_model)

return _target, _mod.params
return cast(Callable, _target), _mod.params


def create_compute_conditional_continuation_value(
Expand Down
48 changes: 24 additions & 24 deletions src/lcm/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jax import Array

from lcm.grids import ContinuousGrid, DiscreteGrid, Grid
from lcm.typing import ParamsDict, Shock
from lcm.typing import ParamsDict, ShockType


@dataclass(frozen=True)
Expand Down Expand Up @@ -73,28 +73,28 @@ class InternalModel:
"""Internal representation of a user model.
Attributes:
grids (dict): Dictionary that maps names of model variables to grids of feasible
values for that variable.
gridspecs (dict): Dictionary that maps names of model variables to
specifications from which grids of feasible values can be built.
variable_info (pd.DataFrame): A table with information about all variables in
the model. The index contains the name of a model variable. The columns are
booleans that are True if the variable has the corresponding property. The
columns are: is_state, is_choice, is_continuous, is_discrete, is_sparse,
is_dense.
functions (dict): Dictionary that maps names of functions to functions. The
functions differ from the user functions in that they all except the
filter functions take ``params`` as keyword argument. If the original
function depended on model parameters, those are automatically extracted
from ``params`` and passed to the original function. Otherwise, the
``params`` argument is simply ignored.
function_info (pd.DataFrame): A table with information about all functions in
the model. The index contains the name of a function. The columns are
booleans that are True if the function has the corresponding property. The
columns are: is_filter, is_constraint, is_next.
params (dict): Dict of model parameters.
n_periods (int): Number of periods.
shocks (Shock): Type of shocks.
grids: Dictionary that maps names of model variables to grids of feasible values
for that variable.
gridspecs: Dictionary that maps names of model variables to specifications from
which grids of feasible values can be built.
variable_info: A table with information about all variables in the model. The
index contains the name of a model variable. The columns are booleans that
are True if the variable has the corresponding property. The columns are:
is_state, is_choice, is_continuous, is_discrete, is_sparse, columns are:
is_state, is_choice, is_continuous, is_discrete, is_sparse, is_dense.
functions: Dictionary that maps names of functions to functions. The functions
differ from the user functions in that they all except the filter functions
take ``params`` as keyword argument. If the original function depended on
model parameters, those are automatically extracted from ``params`` and
passed to the original function. Otherwise, the ``params`` argument is
simply ignored.
function_info: A table with information about all functions in the model. The
index contains the name of a function. The columns are booleans that are
True if the function has the corresponding property. The columns are:
is_filter, is_constraint, is_next.
params: Dict of model parameters.
n_periods: Number of periods.
random_utility_shocks: Type of random utility shocks.
"""

Expand All @@ -106,4 +106,4 @@ class InternalModel:
params: ParamsDict
n_periods: int
# Not properly processed yet
shocks: Shock
random_utility_shocks: ShockType
5 changes: 3 additions & 2 deletions src/lcm/process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from lcm.interfaces import InternalModel
from lcm.model import Model
from lcm.typing import ParamsDict, Shock
from lcm.typing import ParamsDict, ShockType


def process_model(user_model: Model) -> InternalModel:
Expand Down Expand Up @@ -66,7 +66,8 @@ def process_model(user_model: Model) -> InternalModel:
functions=functions,
function_info=function_info,
params=params,
shocks=Shock.NONE, # currently no additive utility shocks are supported
# currently no additive utility shocks are supported
random_utility_shocks=ShockType.NONE,
n_periods=user_model.n_periods,
)

Expand Down
6 changes: 3 additions & 3 deletions src/lcm/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

def simulate(
params,
initial_states,
state_indexers,
continuous_choice_grids,
compute_ccv_policy_functions,
model,
next_state,
initial_states,
logger,
solve_model=None,
vf_arr_list=None,
Expand All @@ -30,6 +30,8 @@ def simulate(
Args:
params (dict): Dict of model parameters.
initial_states (list): List of initial states to start from. Typically from the
observed dataset.
state_indexers (list): List of dicts of length n_periods. Each dict contains one
or several state indexers.
continuous_choice_grids (list): List of dicts of length n_periods. Each dict
Expand All @@ -41,8 +43,6 @@ def simulate(
state and choice variables. For stochastic variables, it returns a random
draw from the distribution of the next state.
model (Model): Model instance.
initial_states (list): List of initial states to start from. Typically from the
observed dataset.
logger (logging.Logger): Logger that logs to stdout.
solve_model (callable): Function that solves the model. Is only required if
vf_arr_list is not provided.
Expand Down
2 changes: 1 addition & 1 deletion src/lcm/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ParamsDict = dict[str, Any]


class Shock(Enum):
class ShockType(Enum):
"""Type of shocks."""

EXTREME_VALUE = "extreme_value"
Expand Down
6 changes: 3 additions & 3 deletions tests/test_discrete_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
_solve_discrete_problem_no_shocks,
get_solve_discrete_problem,
)
from lcm.typing import Shock
from lcm.typing import ShockType


@pytest.fixture
Expand Down Expand Up @@ -50,7 +50,7 @@ def test_aggregation_without_shocks(cc_values, segment_info, collapse, n_extra_a
)

solve_discrete_problem = get_solve_discrete_problem(
random_utility_shock_type=Shock.NONE,
random_utility_shock_type=ShockType.NONE,
variable_info=var_info,
is_last_period=False,
choice_segments=segment_info,
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_get_solve_discrete_problem_illustrative():
) # leads to choice_axes = [1]

solve_discrete_problem = get_solve_discrete_problem(
random_utility_shock_type=Shock.NONE,
random_utility_shock_type=ShockType.NONE,
variable_info=variable_info,
is_last_period=False,
choice_segments=None,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def h():
functions={"f": f, "g": g, "h": h},
function_info=function_info,
params=None,
shocks=None,
random_utility_shocks=None,
n_periods=None,
)
combined_constraint = get_combined_constraint(model)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_next_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def f_weight_b(state): # noqa: ARG001
gridspecs=None,
variable_info=None,
params=None,
shocks=None,
random_utility_shocks=None,
n_periods=1,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_state_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement):
variable_info=None,
functions=functions,
function_info=function_info,
shocks=None,
random_utility_shocks=None,
n_periods=100,
params={},
)
Expand Down

0 comments on commit d126dba

Please sign in to comment.