Skip to content

Commit

Permalink
Integrate comments from review
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 23, 2024
1 parent 420eedd commit 55dad82
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 34 deletions.
52 changes: 33 additions & 19 deletions tests/test_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@
from lcm.input_processing import process_model
from lcm.model_functions import get_utility_and_feasibility_function
from lcm.state_space import create_state_choice_space
from tests.test_models.deterministic import get_model_config
from tests.test_models.deterministic import (
DiscreteConsumptionChoice,
RetirementStatus,
get_model_config,
)
from tests.test_models.deterministic import utility as iskhakov_et_al_2017_utility

# ======================================================================================
# Test cases
# ======================================================================================


STRIPPED_DOWN_AND_FULLY_DISCRETE_MODELS = [
STRIPPED_DOWN_AND_DISCRETE_MODELS = [
"iskhakov_et_al_2017_stripped_down",
"iskhakov_et_al_2017_discrete",
]
Expand Down Expand Up @@ -90,11 +94,8 @@ def test_get_lcm_function_with_simulation_target_simple_fully_discrete():

@pytest.mark.parametrize(
"model",
[
get_model_config(name, n_periods=3)
for name in STRIPPED_DOWN_AND_FULLY_DISCRETE_MODELS
],
ids=STRIPPED_DOWN_AND_FULLY_DISCRETE_MODELS,
[get_model_config(name, n_periods=3) for name in STRIPPED_DOWN_AND_DISCRETE_MODELS],
ids=STRIPPED_DOWN_AND_DISCRETE_MODELS,
)
def test_get_lcm_function_with_simulation_is_coherent(model):
"""Test that solve_and_simulate creates same output as solve then simulate."""
Expand Down Expand Up @@ -153,7 +154,13 @@ def test_get_lcm_function_with_simulation_target_iskhakov_et_al_2017(model):
vf_arr_list=vf_arr_list,
initial_states={
"wealth": jnp.array([10.0, 10.0, 20.0]),
"lagged_retirement": jnp.array([0, 1, 1]),
"lagged_retirement": jnp.array(
[
RetirementStatus.working,
RetirementStatus.retired,
RetirementStatus.retired,
]
),
},
)

Expand Down Expand Up @@ -200,14 +207,14 @@ def test_create_compute_conditional_continuation_value():

val = compute_ccv(
consumption=jnp.array([10, 20, 30.0]),
retirement=1,
retirement=RetirementStatus.retired,
wealth=30,
params=params,
vf_arr=None,
)
assert val == iskhakov_et_al_2017_utility(
consumption=30.0,
working=0,
working=RetirementStatus.working,
disutility_of_work=1.0,
)

Expand Down Expand Up @@ -248,15 +255,17 @@ def test_create_compute_conditional_continuation_value_with_discrete_model():
)

val = compute_ccv(
consumption=jnp.array([0, 1]),
retirement=1,
consumption=jnp.array(
[DiscreteConsumptionChoice.low, DiscreteConsumptionChoice.high]
),
retirement=RetirementStatus.retired,
wealth=2,
params=params,
vf_arr=None,
)
assert val == iskhakov_et_al_2017_utility(
consumption=2,
working=0,
working=RetirementStatus.working,
disutility_of_work=1.0,
)

Expand Down Expand Up @@ -303,15 +312,15 @@ def test_create_compute_conditional_continuation_policy():

policy, val = compute_ccv_policy(
consumption=jnp.array([10, 20, 30.0]),
retirement=1,
retirement=RetirementStatus.retired,
wealth=30,
params=params,
vf_arr=None,
)
assert policy == 2
assert val == iskhakov_et_al_2017_utility(
consumption=30.0,
working=0,
working=RetirementStatus.working,
disutility_of_work=1.0,
)

Expand Down Expand Up @@ -352,16 +361,18 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model():
)

policy, val = compute_ccv_policy(
consumption=jnp.array([0, 1]),
retirement=1,
consumption=jnp.array(
[DiscreteConsumptionChoice.low, DiscreteConsumptionChoice.high]
),
retirement=RetirementStatus.retired,
wealth=2,
params=params,
vf_arr=None,
)
assert policy == 1
assert val == iskhakov_et_al_2017_utility(
consumption=2,
working=0,
working=RetirementStatus.working,
disutility_of_work=1.0,
)

Expand All @@ -375,7 +386,10 @@ def test_get_lcm_function_with_period_argument_in_filter():
model = get_model_config("iskhakov_et_al_2017", n_periods=3)

def absorbing_retirement_filter(retirement, lagged_retirement, _period):
return jnp.logical_or(retirement == 1, lagged_retirement == 0)
return jnp.logical_or(
retirement == RetirementStatus.retired,
lagged_retirement == RetirementStatus.working,
)

model.functions["absorbing_retirement_filter"] = absorbing_retirement_filter

Expand Down
40 changes: 26 additions & 14 deletions tests/test_models/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@ class RetirementStatus:
retired: int = 1


@dataclass
class DiscreteConsumptionChoice:
low: int = 0
high: int = 1


@dataclass
class DiscreteWealthLevels:
low: int = 0
medium: int = 1
high: int = 2


# --------------------------------------------------------------------------------------
# Utility functions
# --------------------------------------------------------------------------------------
Expand All @@ -51,7 +64,7 @@ def utility_with_filter(
def utility_discrete(consumption, working, disutility_of_work):
# In the discrete model, consumption is defined as "low" or "high". This can be
# translated to the levels 1 and 2.
consumption_level = 1 + (consumption == ConsumptionStatus.high)
consumption_level = 1 + (consumption == DiscreteConsumptionChoice.high)
return utility(consumption_level, working, disutility_of_work)


Expand Down Expand Up @@ -81,6 +94,15 @@ def next_wealth(wealth, consumption, labor_income, interest_rate):
return (1 + interest_rate) * (wealth - consumption) + labor_income


def next_wealth_discrete(wealth, consumption, labor_income, interest_rate):
# For discrete state variables, we need to assure that the next state is also a
# valid state, i.e., it is a member of the discrete grid.
continuous = next_wealth(wealth, consumption, labor_income, interest_rate)
return jnp.clip(
jnp.rint(continuous), DiscreteWealthLevels.low, DiscreteWealthLevels.high
).astype(jnp.int32)


# --------------------------------------------------------------------------------------
# Constraints
# --------------------------------------------------------------------------------------
Expand Down Expand Up @@ -170,12 +192,6 @@ def absorbing_retirement_filter(retirement, lagged_retirement):
)


@dataclass
class ConsumptionStatus:
low: int = 0
high: int = 1


ISKHAKOV_ET_AL_2017_DISCRETE = Model(
description=(
"Starts from Iskhakov et al. (2017), removes filters and the lagged_retirement "
Expand All @@ -184,21 +200,17 @@ class ConsumptionStatus:
n_periods=3,
functions={
"utility": utility_discrete,
"next_wealth": next_wealth,
"next_wealth": next_wealth_discrete,
"consumption_constraint": consumption_constraint,
"labor_income": labor_income,
"working": working,
},
choices={
"retirement": DiscreteGrid(RetirementStatus),
"consumption": DiscreteGrid(ConsumptionStatus),
"consumption": DiscreteGrid(DiscreteConsumptionChoice),
},
states={
"wealth": LinspaceGrid(
start=0,
stop=400,
n_points=100,
),
"wealth": DiscreteGrid(DiscreteWealthLevels),
},
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def test_simulate_with_only_discrete_choices():

assert_array_equal(res["retirement"], jnp.array([0, 1, 1, 1]))
assert_array_equal(res["consumption"], jnp.array([0, 1, 1, 1]))
assert_array_equal(res["wealth"], jnp.array([0, 4, 1.5, 3]))
assert_array_equal(res["wealth"], jnp.array([0, 4, 2, 2]))


# ======================================================================================
Expand Down

0 comments on commit 55dad82

Please sign in to comment.