diff --git a/environment.yml b/environment.yml index f283dc5a..e58fe533 100644 --- a/environment.yml +++ b/environment.yml @@ -34,6 +34,7 @@ dependencies: - pre-commit - pydot - snakeviz + - memory_profiler # Documentation - sphinx diff --git a/src/lcm/example_models/example_models_long.py b/src/lcm/example_models/example_models_long.py index ff281f28..ce675035 100644 --- a/src/lcm/example_models/example_models_long.py +++ b/src/lcm/example_models/example_models_long.py @@ -2,21 +2,12 @@ import jax.numpy as jnp RETIREMENT_AGE = 65 -N_CHOICE_GRID_POINTS = 5_000 -N_STATE_GRID_POINTS = 5_000 +N_CHOICE_GRID_POINTS = 200 +N_STATE_GRID_POINTS = 100 -def phelps_deaton_utility(consumption, working, delta): - return jnp.log(consumption) - delta * working - - -def phelps_deaton_utility_with_filter( - consumption, - working, - delta, - lagged_retirement, # noqa: ARG001 -): - return jnp.log(consumption) - delta * working +def phelps_deaton_utility(consumption, working, health, sport, delta): + return jnp.log(consumption + 1) - (delta - health) * working - sport def working(retirement): @@ -38,6 +29,10 @@ def next_wealth(wealth, consumption, working, wage, interest_rate): return (1 + interest_rate) * (wealth - consumption) + wage * working +def next_health(health, sport, working): + return health * (1 + sport - working / 2) + + def consumption_constraint(consumption, wealth): return consumption <= wealth @@ -58,6 +53,7 @@ def age(_period): "working": working, "wage": wage, "age": age, + "next_health": next_health, }, "choices": { "retirement": {"options": [0, 1]}, @@ -67,6 +63,12 @@ def age(_period): "stop": 100, "n_points": N_CHOICE_GRID_POINTS, }, + "sport": { + "grid_type": "linspace", + "start": 0, + "stop": 1, + "n_points": N_CHOICE_GRID_POINTS, + }, }, "states": { "wealth": { @@ -78,11 +80,15 @@ def age(_period): "health": { "grid_type": "linspace", "start": 0, - "stop": 100, + "stop": 1, "n_points": N_STATE_GRID_POINTS, }, }, "n_periods": RETIREMENT_AGE - 18, } -PARAMS = {} +PARAMS = { + "beta": 0.95, + "utility": {"delta": 0.05}, + "next_wealth": {"interest_rate": 0.05}, +} diff --git a/tests/test_long.py b/tests/test_long.py new file mode 100644 index 00000000..af7dfe77 --- /dev/null +++ b/tests/test_long.py @@ -0,0 +1,13 @@ +import pytest +from lcm.entry_point import get_lcm_function +from lcm.example_models.example_models_long import PARAMS, PHELPS_DEATON + +SKIP_REASON = """The test is designed to run approximately 1 minute on a standard +laptop, such that we can differentiate the performance of running LCM on a GPU versus +on the CPU.""" + + +@pytest.skip(reason=SKIP_REASON) +def test_long(): + solve_model, template = get_lcm_function(PHELPS_DEATON, targets="solve") + solve_model(PARAMS)