Skip to content

Commit

Permalink
Refactor example models
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Feb 28, 2024
1 parent 4f97784 commit 39a271d
Show file tree
Hide file tree
Showing 19 changed files with 271 additions and 211 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ repos:
- id: name-tests-test
args:
- --pytest-test-first
exclude: ^tests/test_models/
- id: no-commit-to-branch
args:
- --branch
Expand Down
5 changes: 3 additions & 2 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ exclude *.yml
exclude *.pickle
exclude pytask.ini

prune src/lcm/sandbox
prune .envs
prune examples
prune docs
prune src/lcm/sandbox
prune tests
prune .envs

global-exclude __pycache__
global-exclude *.py[co]
Expand Down
22 changes: 22 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Example model specifications

## Choosing an example

| Example name | Description | Runtime |
| -------------- | -------------------------------------------------- | ------------- |
| `long_running` | Consumptions-savings model with health and leisure | a few minutes |

## Running an example

Say you want to run the [`long_running`](./long_running.py) example locally. In a Python
shell, execute:

```python
from lcm.entry_point import get_lcm_function

from long_running import MODEL_CONFIG, PARAMS


solve_model, _ = get_lcm_function(model=MODEL_CONFIG)
vf_arr = solve_model(PARAMS)
```
121 changes: 121 additions & 0 deletions examples/long_running.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Example specification for a consumption-savings model with health and leisure."""

import jax.numpy as jnp

# ======================================================================================
# Numerical parameters and constants
# ======================================================================================
N_GRID_POINTS = {
"states": 100,
"choices": 200,
}

RETIREMENT_AGE = 65

# ======================================================================================
# Model functions
# ======================================================================================


# --------------------------------------------------------------------------------------
# Utility function
# --------------------------------------------------------------------------------------
def utility(consumption, working, health, sport, delta):
return jnp.log(consumption) - (delta - health) * working - sport


# --------------------------------------------------------------------------------------
# Auxiliary variables
# --------------------------------------------------------------------------------------
def working(leisure):
return 1 - leisure


def wage(age):
return 1 + 0.1 * age


def age(_period):
return _period + 18


# --------------------------------------------------------------------------------------
# State transitions
# --------------------------------------------------------------------------------------
def next_wealth(wealth, consumption, working, wage, interest_rate):
return (1 + interest_rate) * (wealth - consumption) + wage * working


def next_health(health, sport, working):
return health * (1 + sport - working / 2)


def next_wealth_with_shock(
wealth,
consumption,
working,
wage,
wage_shock,
interest_rate,
):
return interest_rate * (wealth - consumption) + wage * wage_shock * working


# --------------------------------------------------------------------------------------
# Constraints
# --------------------------------------------------------------------------------------
def consumption_constraint(consumption, wealth):
return consumption <= wealth


# ======================================================================================
# Model specification and parameters
# ======================================================================================

MODEL_CONFIG = {
"functions": {
"utility": utility,
"next_wealth": next_wealth,
"consumption_constraint": consumption_constraint,
"working": working,
"wage": wage,
"age": age,
"next_health": next_health,
},
"choices": {
"leisure": {"options": [0, 1]},
"consumption": {
"grid_type": "linspace",
"start": 1,
"stop": 100,
"n_points": N_GRID_POINTS["choices"],
},
"sport": {
"grid_type": "linspace",
"start": 0,
"stop": 1,
"n_points": N_GRID_POINTS["choices"],
},
},
"states": {
"wealth": {
"grid_type": "linspace",
"start": 1,
"stop": 100,
"n_points": N_GRID_POINTS["states"],
},
"health": {
"grid_type": "linspace",
"start": 0,
"stop": 1,
"n_points": N_GRID_POINTS["states"],
},
},
"n_periods": RETIREMENT_AGE - 18,
}

PARAMS = {
"beta": 0.95,
"utility": {"delta": 0.05},
"next_wealth": {"interest_rate": 0.05},
}
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ write_to = "src/lcm/_version.py"
[tool.ruff]
target-version = "py311"
fix = true

[tool.ruff.lint]
select = ["ALL"]
extend-ignore = [
# missing type annotation
Expand Down Expand Up @@ -57,6 +59,7 @@ extend-ignore = [
[tool.ruff.lint.per-file-ignores]
"docs/source/conf.py" = ["E501", "ERA001", "DTZ005"]
"tests/test_*.py" = ["PLR2004"]
"examples/*" = ["INP001"]

[tool.ruff.lint.pydocstyle]
convention = "google"
Expand All @@ -75,7 +78,7 @@ markers = [
"slow: Tests that take a long time to run and are skipped in continuous integration.",
"illustrative: Tests are designed for illustrative purposes",
]
norecursedirs = ["docs", ".envs"]
norecursedirs = ["docs", ".envs", "tests/test_models"]


[tool.yamlfix]
Expand Down
95 changes: 0 additions & 95 deletions src/lcm/example_models/testing_example_models.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/lcm/get_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pybaum import tree_update

from lcm.example_models.basic_example_models import (
from tests.test_models.phelps_deaton import (
PHELPS_DEATON,
PHELPS_DEATON_WITH_FILTERS,
)
Expand Down
8 changes: 0 additions & 8 deletions src/lcm/test.py

This file was deleted.

21 changes: 11 additions & 10 deletions tests/test_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
create_compute_conditional_continuation_value,
get_lcm_function,
)
from lcm.example_models.basic_example_models import (
PHELPS_DEATON,
PHELPS_DEATON_FULLY_DISCRETE,
PHELPS_DEATON_WITH_FILTERS,
phelps_deaton_utility,
)
from lcm.model_functions import get_utility_and_feasibility_function
from lcm.process_model import process_model
from lcm.state_space import create_state_choice_space
from pybaum import tree_equal, tree_map

from tests.test_models.phelps_deaton import (
PHELPS_DEATON,
PHELPS_DEATON_FULLY_DISCRETE,
PHELPS_DEATON_WITH_FILTERS,
utility,
)

MODELS = {
"simple": PHELPS_DEATON,
"with_filters": PHELPS_DEATON_WITH_FILTERS,
Expand Down Expand Up @@ -175,7 +176,7 @@ def test_create_compute_conditional_continuation_value():
params=params,
vf_arr=None,
)
assert val == phelps_deaton_utility(consumption=30.0, working=0, delta=1.0)
assert val == utility(consumption=30.0, working=0, delta=1.0)


def test_create_compute_conditional_continuation_value_with_discrete_model():
Expand Down Expand Up @@ -218,7 +219,7 @@ def test_create_compute_conditional_continuation_value_with_discrete_model():
params=params,
vf_arr=None,
)
assert val == phelps_deaton_utility(consumption=2, working=0, delta=1.0)
assert val == utility(consumption=2, working=0, delta=1.0)


# ======================================================================================
Expand Down Expand Up @@ -267,7 +268,7 @@ def test_create_compute_conditional_continuation_policy():
vf_arr=None,
)
assert policy == 2
assert val == phelps_deaton_utility(consumption=30.0, working=0, delta=1.0)
assert val == utility(consumption=30.0, working=0, delta=1.0)


def test_create_compute_conditional_continuation_policy_with_discrete_model():
Expand Down Expand Up @@ -311,4 +312,4 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model():
vf_arr=None,
)
assert policy == 1
assert val == phelps_deaton_utility(consumption=2, working=0, delta=1.0)
assert val == utility(consumption=2, working=0, delta=1.0)
13 changes: 0 additions & 13 deletions tests/test_long.py

This file was deleted.

Loading

0 comments on commit 39a271d

Please sign in to comment.