Skip to content

Commit

Permalink
Add tests and fix simulation with non-index discrete grids
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 20, 2024
1 parent 10180f0 commit 2cf15b0
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 41 deletions.
83 changes: 52 additions & 31 deletions src/lcm/input_processing/discrete_grid_conversion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import functools
from collections.abc import Callable
from dataclasses import dataclass, field, make_dataclass
from typing import TypeVar, cast

import jax.numpy as jnp
from dags.signature import with_signature
Expand Down Expand Up @@ -70,49 +72,36 @@ def params_to_internal(self, params: ParamsDict) -> ParamsDict:
out[f"next___{var}_index__"] = out.pop(old_name)
return out

def internal_to_states(self, states: dict[str, Array]) -> dict[str, Array]:
def internal_to_discrete_vars(
self, variables: dict[str, Array]
) -> dict[str, Array]:
"""Convert states from internal to external representation.
If a state has been converted, the name of its corresponding index function must
be changed from `___{var}_index__` to `{var}`, and the values of the state must
be converted from indices to codes.
If a variable has been converted, the name of its corresponding index function
must be changed from `___{var}_index__` to `{var}`, and the values of the
variable must be converted from indices to codes.
"""
out = states.copy()
out = variables.copy()
for var, index_to_code in self.index_to_code.items():
old_name = f"__{var}_index__"
if old_name in states:
if old_name in variables:
out[var] = index_to_code(out.pop(old_name))
return out

def states_to_internal(self, states: dict[str, Array]) -> dict[str, Array]:
def discrete_vars_to_internal(
self, variables: dict[str, Array]
) -> dict[str, Array]:
"""Convert states from external to internal representation.
If a state has been converted, the name of its corresponding index function must
be changed from `{var}` to `___{var}_index__`, and the values of the state must
be converted from codes to indices.
If a variable has been converted, the name of its corresponding index function
must be changed from `{var}` to `___{var}_index__`, and the values of the
variable must be converted from codes to indices.
"""
out = states.copy()
out = variables.copy()
for var, code_to_index in self.code_to_index.items():
if var in states:
out[f"__{var}_index__"] = code_to_index(out.pop(var))
return out

def internal_to_choices(self, choices: dict[str, Array]) -> dict[str, Array]:
"""Convert choices from internal to external representation."""
out = choices.copy()
for var, index_to_code in self.index_to_code.items():
old_name = f"__{var}_index__"
if old_name in choices:
out[var] = index_to_code(out.pop(old_name))
return out

def choices_to_internal(self, choices: dict[str, Array]) -> dict[str, Array]:
"""Convert choices from external to internal representation."""
out = choices.copy()
for var, code_to_index in self.code_to_index.items():
if var in choices:
if var in variables:
out[f"__{var}_index__"] = code_to_index(out.pop(var))
return out

Expand Down Expand Up @@ -172,7 +161,11 @@ def convert_arbitrary_codes_to_array_indices(
non_index_states = [s for s in model.states if s in non_index_discrete_vars]

for var in non_index_states:
functions[f"next___{var}_index__"] = functions.pop(f"next_{var}")
functions[f"next___{var}_index__"] = _get_next_index_func(
functions.pop(f"next_{var}"),
codes_array=gridspecs[var].to_jax(),
name=var,
)

# Add index to code functions
# ----------------------------------------------------------------------------------
Expand Down Expand Up @@ -218,6 +211,31 @@ def _get_discrete_vars_with_non_index_codes(model: Model) -> list[str]:
return discrete_vars


F = TypeVar("F", bound=Callable[[tuple[Array, ...]], Array])


def _get_next_index_func(next_func: F, codes_array: Array, name: str) -> F:
"""Get next function for index state variable.
Args:
next_func: The next function for the state variable.
codes_array: An array of codes.
name: The name of the state variable.
Returns:
A next function corresponding to the index version of the state variable.
"""
code_to_index = _get_code_to_index_func(codes_array, name)

@functools.wraps(next_func)
def next_index_func(*args, **kwargs) -> Array:
next_state = next_func(*args, **kwargs)
return code_to_index(next_state)

return cast(F, next_index_func)


def _get_index_to_code_func(codes_array: Array, name: str) -> Callable[[Array], Array]:
"""Get function mapping from index to code.
Expand Down Expand Up @@ -253,11 +271,14 @@ def _get_code_to_index_func(codes_array: Array, name: str) -> Callable[[Array],
indices.
"""
sorted_indices = jnp.argsort(codes_array)
sorted_codes = codes_array[sorted_indices]

@with_signature(args=[name])
def code_to_index(*args, **kwargs):
kwargs = all_as_kwargs(args, kwargs, arg_names=[name])
data = kwargs[name]
return jnp.argmax(data[:, None] == codes_array[None, :], axis=1)
indices_in_sorted = jnp.searchsorted(sorted_codes, data)
return sorted_indices[indices_in_sorted]

return code_to_index
10 changes: 6 additions & 4 deletions src/lcm/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def simulate(
sparse_choice_variables = model.variable_info.query("is_choice & is_sparse").index

# The following variables are updated during the forward simulation
states = discrete_grid_converter.states_to_internal(initial_states)
states = discrete_grid_converter.discrete_vars_to_internal(initial_states)
key = jax.random.PRNGKey(seed=seed)

# Forward simulation
Expand Down Expand Up @@ -213,7 +213,7 @@ def simulate(

logger.info("Period: %s", period)

processed = _process_simulated_data(_simulation_results)
processed = _process_simulated_data(_simulation_results, discrete_grid_converter)

if additional_targets is not None:
calculated_targets = _compute_targets(
Expand Down Expand Up @@ -339,7 +339,7 @@ def _compute_targets(processed_results, targets, model_functions, params):
return target_func(params=params, **kwargs)


def _process_simulated_data(results):
def _process_simulated_data(results, discrete_grid_converter):
"""Process and flatten the simulation results.
This function produces a dict of arrays for each var with dimension (n_periods *
Expand All @@ -352,6 +352,7 @@ def _process_simulated_data(results):
results (list): List of dicts with simulation results. Each dict contains the
value, choices, and states for one period. Choices and states are stored in
a nested dictionary.
discrete_grid_converter (DiscreteGridConverter): Converter for discrete grids.
Returns:
dict: Dict with processed simulation results. The keys are the variable names
Expand All @@ -370,7 +371,8 @@ def _process_simulated_data(results):
}
out = {key: jnp.concatenate(values) for key, values in dict_of_lists.items()}
out["_period"] = jnp.repeat(jnp.arange(n_periods), n_initial_states)
return out

return discrete_grid_converter.internal_to_discrete_vars(out)


# ======================================================================================
Expand Down
16 changes: 12 additions & 4 deletions tests/input_processing/test_discrete_state_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,30 @@ def test_discrete_state_converter_params_to_internal(discrete_state_converter_kw
assert converter.params_to_internal(params) == expected


def test_discrete_state_converter_internal_to_states(discrete_state_converter_kwargs):
def test_discrete_state_converter_internal_to_discrete_vars(
discrete_state_converter_kwargs,
):
expected = jnp.array([1, 0])
internal_states = {
"__c_index__": jnp.array([0, 1]),
}
converter = DiscreteGridConverter(**discrete_state_converter_kwargs)
assert_array_equal(converter.internal_to_states(internal_states)["c"], expected)
assert_array_equal(
converter.internal_to_discrete_vars(internal_states)["c"], expected
)


def test_discrete_state_converter_states_to_internal(discrete_state_converter_kwargs):
def test_discrete_state_converter_discrete_vars_to_internal(
discrete_state_converter_kwargs,
):
expected = jnp.array([0, 1])
states = {
"c": jnp.array([1, 0]),
}
converter = DiscreteGridConverter(**discrete_state_converter_kwargs)
assert_array_equal(converter.states_to_internal(states)["__c_index__"], expected)
assert_array_equal(
converter.discrete_vars_to_internal(states)["__c_index__"], expected
)


def test_discrete_state_converter_raises_error_if_keys_dont_match():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_models/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def next_wealth(wealth, consumption, labor_income, interest_rate):
# the grid. We therefore round the result of the continuous state transition function.
def next_wealth_discrete(wealth, consumption, labor_income, interest_rate):
next_wealth_cont = next_wealth(wealth, consumption, labor_income, interest_rate)
return jnp.rint(next_wealth_cont).astype(jnp.int32)
return jnp.clip(jnp.rint(next_wealth_cont).astype(jnp.int32), 1, 400)


# --------------------------------------------------------------------------------------
Expand Down
30 changes: 29 additions & 1 deletion tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,32 @@ def test_simulate_using_get_lcm_function(
assert (res.loc[period]["value"].diff()[1:] >= 0).all()


# ======================================================================================
# Test simulation works correctly with discrete grid conversion
# ======================================================================================


def test_simulate_with_discrete_grid_conversion():
"""Test that the simulation works correctly with discrete grid conversion."""
model = get_model_config("iskhakov_et_al_2017_fully_discrete", n_periods=2)
params = get_params(wage=2)

simulate_model, _ = get_lcm_function(model=model, targets="solve_and_simulate")

res = simulate_model(
params,
initial_states={"wealth": jnp.array([1, 4])},
additional_targets=["labor_income", "working"],
)

assert "__consumption_index__" not in res.columns
assert "__wealth_index__" not in res.columns

assert_array_equal(res["retirement"], jnp.array([0, 1, 1, 1]))
assert_array_equal(res["consumption"], jnp.array([1, 2, 2, 2]))
assert_array_equal(res["wealth"], jnp.array([1, 4, 2, 2]))


# ======================================================================================
# Testing effects of parameters
# ======================================================================================
Expand Down Expand Up @@ -354,7 +380,9 @@ def test_process_simulated_data():
"b": jnp.array([-1, -2, -3, -4]),
}

got = _process_simulated_data(simulated)
got = _process_simulated_data(
simulated, discrete_grid_converter=DiscreteGridConverter()
)
assert tree_equal(expected, got)


Expand Down

0 comments on commit 2cf15b0

Please sign in to comment.