Skip to content

Commit

Permalink
Add test for long running example model
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Feb 7, 2024
1 parent 407fb4a commit 01a4461
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 15 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies:
- pre-commit
- pydot
- snakeviz
- memory_profiler

# Documentation
- sphinx
Expand Down
36 changes: 21 additions & 15 deletions src/lcm/example_models/example_models_long.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,12 @@
import jax.numpy as jnp

RETIREMENT_AGE = 65
N_CHOICE_GRID_POINTS = 5_000
N_STATE_GRID_POINTS = 5_000
N_CHOICE_GRID_POINTS = 200
N_STATE_GRID_POINTS = 100


def phelps_deaton_utility(consumption, working, delta):
return jnp.log(consumption) - delta * working


def phelps_deaton_utility_with_filter(
consumption,
working,
delta,
lagged_retirement, # noqa: ARG001
):
return jnp.log(consumption) - delta * working
def phelps_deaton_utility(consumption, working, health, sport, delta):
return jnp.log(consumption + 1) - (delta - health) * working - sport


def working(retirement):
Expand All @@ -38,6 +29,10 @@ def next_wealth(wealth, consumption, working, wage, interest_rate):
return (1 + interest_rate) * (wealth - consumption) + wage * working


def next_health(health, sport, working):
return health * (1 + sport - working / 2)


def consumption_constraint(consumption, wealth):
return consumption <= wealth

Expand All @@ -58,6 +53,7 @@ def age(_period):
"working": working,
"wage": wage,
"age": age,
"next_health": next_health,
},
"choices": {
"retirement": {"options": [0, 1]},
Expand All @@ -67,6 +63,12 @@ def age(_period):
"stop": 100,
"n_points": N_CHOICE_GRID_POINTS,
},
"sport": {
"grid_type": "linspace",
"start": 0,
"stop": 1,
"n_points": N_CHOICE_GRID_POINTS,
},
},
"states": {
"wealth": {
Expand All @@ -78,11 +80,15 @@ def age(_period):
"health": {
"grid_type": "linspace",
"start": 0,
"stop": 100,
"stop": 1,
"n_points": N_STATE_GRID_POINTS,
},
},
"n_periods": RETIREMENT_AGE - 18,
}

PARAMS = {}
PARAMS = {
"beta": 0.95,
"utility": {"delta": 0.05},
"next_wealth": {"interest_rate": 0.05},
}
13 changes: 13 additions & 0 deletions tests/test_long.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest
from lcm.entry_point import get_lcm_function
from lcm.example_models.example_models_long import PARAMS, PHELPS_DEATON

SKIP_REASON = """The test is designed to run approximately 1 minute on a standard
laptop, such that we can differentiate the performance of running LCM on a GPU versus
on the CPU."""


@pytest.skip(reason=SKIP_REASON)
def test_long():
solve_model, template = get_lcm_function(PHELPS_DEATON, targets="solve")
solve_model(PARAMS)

0 comments on commit 01a4461

Please sign in to comment.