Skip to content

Commit

Permalink
Start refactoring input processing
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 10, 2024
1 parent 6603e02 commit 357e29e
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 60 deletions.
2 changes: 0 additions & 2 deletions src/lcm/discrete_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def get_solve_discrete_problem(
variable_info: pd.DataFrame,
is_last_period: bool,
choice_segments: SegmentInfo | None,
params: ParamsDict,
) -> Callable[[Array], Array]:
"""Get function that computes the expected max. of conditional continuation values.
Expand Down Expand Up @@ -72,7 +71,6 @@ def get_solve_discrete_problem(
func,
choice_axes=choice_axes,
choice_segments=choice_segments,
params=params,
)


Expand Down
1 change: 0 additions & 1 deletion src/lcm/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def get_lcm_function(
variable_info=_mod.variable_info,
is_last_period=is_last_period,
choice_segments=choice_segments[period],
params=_mod.params,
)
emax_calculators.append(calculator)

Expand Down
43 changes: 13 additions & 30 deletions src/lcm/input_processing/process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,9 @@ def process_model(user_model: Model) -> InternalModel:
"""
function_info = _get_function_info(user_model)

variable_info = _get_variable_info(
user_model,
function_info=function_info,
)

gridspecs = _get_gridspecs(user_model, variable_info=variable_info)

grids = _get_grids(gridspecs=gridspecs, variable_info=variable_info)
variable_info = _get_variable_info(user_model)
gridspecs = _get_gridspecs(user_model)
grids = _get_grids(user_model)

params = create_params_template(
user_model,
Expand Down Expand Up @@ -97,15 +91,11 @@ def _get_function_info(user_model: Model) -> pd.DataFrame:
return info


def _get_variable_info(user_model: Model, function_info: pd.DataFrame) -> pd.DataFrame:
def _get_variable_info(user_model: Model) -> pd.DataFrame:
"""Derive information about all variables in the model.
Args:
user_model: The model as provided by the user.
function_info: A table with information about all functions in the model. The
index contains the name of a function. The columns are booleans that are
True if the function has the corresponding property. The columns are:
is_filter, is_constraint, is_next, is_stochastic_next.
Returns:
pd.DataFrame: A table with information about all variables in the model. The
Expand All @@ -114,6 +104,8 @@ def _get_variable_info(user_model: Model, function_info: pd.DataFrame) -> pd.Dat
is_state, is_choice, is_continuous, is_discrete, is_sparse, is_dense.
"""
function_info = _get_function_info(user_model)

variables = user_model.states | user_model.choices

info = pd.DataFrame(index=list(variables))
Expand Down Expand Up @@ -193,17 +185,11 @@ def _get_auxiliary_variables(

def _get_gridspecs(
user_model: Model,
variable_info: pd.DataFrame,
) -> dict[str, Grid]:
"""Create a dictionary of grid specifications for each variable in the model.
Args:
user_model (dict): The model as provided by the user.
variable_info (pandas.DataFrame): A table with information about all
variables in the model. The index contains the name of a model variable.
The columns are booleans that are True if the variable has the
corresponding property. The columns are: is_state, is_choice, is_continuous,
is_discrete, is_sparse, is_dense.
Returns:
dict: Dictionary containing all variables of the model. The keys are
Expand All @@ -212,32 +198,29 @@ def _get_gridspecs(
variables this is information about how to build the grids.
"""
variable_info = _get_variable_info(user_model)

raw_variables = user_model.states | user_model.choices
order = variable_info.index.tolist()
return {k: raw_variables[k] for k in order}


def _get_grids(
gridspecs: dict[str, Grid],
variable_info: pd.DataFrame,
user_model: Model,
) -> dict[str, Array]:
"""Create a dictionary of array grids for each variable in the model.
Args:
gridspecs: Dictionary containing all variables of the model. The keys are the
names of the variables. The values describe which values the variable can
take. For discrete variables these are the options (jnp.array). For
continuous variables this is information about how to build the grids.
variable_info: A table with information about all variables in the model. The
index contains the name of a model variable. The columns are booleans that
are True if the variable has the corresponding property. The columns are:
is_state, is_choice, is_continuous, is_discrete, is_sparse, is_dense.
user_model: The model as provided by the user.
Returns:
dict: Dictionary containing all variables of the model. The keys are
the names of the variables. The values are the grids.
"""
variable_info = _get_variable_info(user_model)
gridspecs = _get_gridspecs(user_model)

grids = {name: spec.to_jax() for name, spec in gridspecs.items()}
order = variable_info.index.tolist()
return {k: grids[k] for k in order}
Expand Down
2 changes: 1 addition & 1 deletion src/lcm/solve_brute.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def solve(

# solve discrete problem by calculating expected maximum over discrete choices
calculate_emax = emax_calculators[period]
vf_arr = calculate_emax(conditional_continuation_values)
vf_arr = calculate_emax(conditional_continuation_values, params=params)
reversed_solution.append(vf_arr)

logger.info("Period: %s", period)
Expand Down
25 changes: 4 additions & 21 deletions tests/input_processing/test_process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,7 @@ def test_get_function_info(user_model):


def test_get_variable_info(user_model):
function_info = _get_function_info(user_model)
got = _get_variable_info(
user_model,
function_info,
)
got = _get_variable_info(user_model)
exp = pd.DataFrame(
{
"is_state": [False, True],
Expand All @@ -94,22 +90,13 @@ def test_get_variable_info(user_model):


def test_get_gridspecs(user_model):
variable_info = _get_variable_info(
user_model,
function_info=_get_function_info(user_model),
)
got = _get_gridspecs(user_model, variable_info)
got = _get_gridspecs(user_model)
assert got["a"] == DiscreteGrid([0, 1])
assert got["c"] == DiscreteGrid([0, 1])


def test_get_grids(user_model):
variable_info = _get_variable_info(
user_model,
function_info=_get_function_info(user_model),
)
gridspecs = _get_gridspecs(user_model, variable_info)
got = _get_grids(gridspecs, variable_info)
got = _get_grids(user_model)
assert_array_equal(got["a"], jnp.array([0, 1]))
assert_array_equal(got["c"], jnp.array([0, 1]))

Expand Down Expand Up @@ -284,9 +271,5 @@ def wealth_filter(wealth):

user_model.functions["wealth_filter"] = wealth_filter

function_info = _get_function_info(user_model)
got = _get_variable_info(
user_model,
function_info,
)
got = _get_variable_info(user_model)
assert got.index.is_unique
6 changes: 2 additions & 4 deletions tests/test_discrete_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,9 @@ def test_aggregation_without_shocks(cc_values, segment_info, collapse, n_extra_a
variable_info=var_info,
is_last_period=False,
choice_segments=segment_info,
params={},
)

calculated = solve_discrete_problem(cc_values)
calculated = solve_discrete_problem(cc_values, params=None)

expected = jnp.array([8, 9.5])

Expand Down Expand Up @@ -159,7 +158,6 @@ def test_get_solve_discrete_problem_illustrative():
variable_info=variable_info,
is_last_period=False,
choice_segments=None,
params=None,
)

cc_values = jnp.array(
Expand All @@ -170,7 +168,7 @@ def test_get_solve_discrete_problem_illustrative():
],
)

got = solve_discrete_problem(cc_values)
got = solve_discrete_problem(cc_values, params=None)
aaae(got, jnp.array([1, 3, 5]))


Expand Down
2 changes: 1 addition & 1 deletion tests/test_solve_brute.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _get_continuation_value(lazy, wealth, vf_arr):
# create emax aggregators and choice segments
# ==================================================================================

def calculate_emax(values):
def calculate_emax(values, params): # noqa: ARG001
"""Take max over axis that corresponds to working."""
return values.max(axis=1)

Expand Down

0 comments on commit 357e29e

Please sign in to comment.