Skip to content

Commit

Permalink
Replace Params type by much simpler ParamsDict type
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 5, 2024
1 parent 3cef907 commit 048ce87
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 17 deletions.
10 changes: 5 additions & 5 deletions src/lcm/create_params_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from jax import Array

from lcm.model import Model
from lcm.typing import Params
from lcm.typing import ParamsDict


def create_params_template(
user_model: Model,
variable_info: pd.DataFrame,
grids: dict[str, Array],
default_params: dict[str, int | float] | None = None,
) -> Params:
) -> ParamsDict:
"""Create parameter template from a model specification.
Args:
Expand All @@ -28,7 +28,7 @@ def create_params_template(
np.nan} for beta-delta discounting.
Returns:
dict: A nested dictionary of model parameters.
ParamsDict: A nested dictionary of model parameters.
"""
if default_params is None:
Expand All @@ -51,7 +51,7 @@ def create_params_template(
return default_params | function_params | stochastic_transition_params


def _create_function_params(user_model: Model) -> Params:
def _create_function_params(user_model: Model) -> dict[str, dict[str, float]]:
"""Get function parameters from a model specification.
Explanation: We consider the arguments of all model functions, from which we exclude
Expand Down Expand Up @@ -79,7 +79,7 @@ def _create_function_params(user_model: Model) -> Params:
if hasattr(user_model, "shocks"):
variables = variables | set(user_model.shocks)

function_params: Params = {}
function_params = {}
# For each model function, capture the arguments of the function that are not in the
# set of model variables, and initialize them.
for name, func in user_model.functions.items():
Expand Down
6 changes: 3 additions & 3 deletions src/lcm/process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from lcm.interfaces import InternalModel
from lcm.model import Model
from lcm.typing import Params
from lcm.typing import ParamsDict


def process_model(user_model: Model) -> InternalModel:
Expand Down Expand Up @@ -247,7 +247,7 @@ def _get_functions(
function_info: pd.DataFrame,
variable_info: pd.DataFrame,
grids: dict[str, Array],
params: Params,
params: ParamsDict,
) -> dict[str, Callable]:
"""Process the user provided model functions.
Expand Down Expand Up @@ -327,7 +327,7 @@ def _get_functions(


def _replace_func_parameters_by_params(
func: Callable, params: Params, name: str
func: Callable, params: ParamsDict, name: str
) -> Callable:
old_signature = list(inspect.signature(func).parameters)
new_kwargs = [
Expand Down
11 changes: 2 additions & 9 deletions src/lcm/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated, Literal, TypedDict
from typing import Annotated, Any, Literal, TypedDict

from jax import Array

Expand All @@ -10,14 +10,7 @@

DiscreteLabels = Annotated[list[int], "Int range starting from 0 with increments of 1"]

# Parameters in LCM are made out of three categories: (1) the default parameters
# required for the model class. They appear as dict[str, int | float]. (2) the
# parameters corresponding to the user model functions. They appear as
# dict[str, dict[str, int | float]], where for each user function the parameters
# for this function are stored in a dict. (3) the parameters corresponding to the
# the stochastic transitions. They appear as dict[str, dict[str, Array]],
# where for each stochastic state variable the transition matrix is stored as an Array.
Params = dict[str, int | float | dict[str, int | float] | dict[str, Array]]
ParamsDict = dict[str, Any]


class SegmentInfo(TypedDict):
Expand Down

0 comments on commit 048ce87

Please sign in to comment.