From 39a271d2b8543b400ee52460ce8e754d4f11a203 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 28 Feb 2024 12:24:04 +0100 Subject: [PATCH] Refactor example models --- .pre-commit-config.yaml | 1 + MANIFEST.in | 5 +- examples/README.md | 22 ++++ examples/long_running.py | 121 ++++++++++++++++++ pyproject.toml | 5 +- .../example_models/testing_example_models.py | 95 -------------- src/lcm/get_model.py | 2 +- src/lcm/test.py | 8 -- tests/test_entry_point.py | 21 +-- tests/test_long.py | 13 -- tests/test_model_functions.py | 5 +- .../test_models}/__init__.py | 0 .../test_models/phelps_deaton.py | 105 +++++++-------- .../test_models/stochastic.py | 39 +++++- tests/test_next_state.py | 3 +- tests/test_process_model.py | 13 +- tests/test_simulate.py | 11 +- tests/test_state_space.py | 3 +- tests/test_stochastic.py | 10 +- 19 files changed, 271 insertions(+), 211 deletions(-) create mode 100644 examples/README.md create mode 100644 examples/long_running.py delete mode 100644 src/lcm/example_models/testing_example_models.py delete mode 100644 src/lcm/test.py delete mode 100644 tests/test_long.py rename {src/lcm/example_models => tests/test_models}/__init__.py (100%) rename src/lcm/example_models/basic_example_models.py => tests/test_models/phelps_deaton.py (53%) rename src/lcm/example_models/stochastic_example_models.py => tests/test_models/stochastic.py (67%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6d1270fa..28ef322c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,6 +33,7 @@ repos: - id: name-tests-test args: - --pytest-test-first + exclude: ^tests/test_models/ - id: no-commit-to-branch args: - --branch diff --git a/MANIFEST.in b/MANIFEST.in index ce3bc0f2..1a6451bb 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -21,10 +21,11 @@ exclude *.yml exclude *.pickle exclude pytask.ini -prune src/lcm/sandbox +prune .envs +prune examples prune docs +prune src/lcm/sandbox prune tests -prune .envs global-exclude __pycache__ global-exclude *.py[co] diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..b7e51f44 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,22 @@ +# Example model specifications + +## Choosing an example + +| Example name | Description | Runtime | +| -------------- | -------------------------------------------------- | ------------- | +| `long_running` | Consumptions-savings model with health and leisure | a few minutes | + +## Running an example + +Say you want to run the [`long_running`](./long_running.py) example locally. In a Python +shell, execute: + +```python +from lcm.entry_point import get_lcm_function + +from long_running import MODEL_CONFIG, PARAMS + + +solve_model, _ = get_lcm_function(model=MODEL_CONFIG) +vf_arr = solve_model(PARAMS) +``` diff --git a/examples/long_running.py b/examples/long_running.py new file mode 100644 index 00000000..4502e763 --- /dev/null +++ b/examples/long_running.py @@ -0,0 +1,121 @@ +"""Example specification for a consumption-savings model with health and leisure.""" + +import jax.numpy as jnp + +# ====================================================================================== +# Numerical parameters and constants +# ====================================================================================== +N_GRID_POINTS = { + "states": 100, + "choices": 200, +} + +RETIREMENT_AGE = 65 + +# ====================================================================================== +# Model functions +# ====================================================================================== + + +# -------------------------------------------------------------------------------------- +# Utility function +# -------------------------------------------------------------------------------------- +def utility(consumption, working, health, sport, delta): + return jnp.log(consumption) - (delta - health) * working - sport + + +# -------------------------------------------------------------------------------------- +# Auxiliary variables +# -------------------------------------------------------------------------------------- +def working(leisure): + return 1 - leisure + + +def wage(age): + return 1 + 0.1 * age + + +def age(_period): + return _period + 18 + + +# -------------------------------------------------------------------------------------- +# State transitions +# -------------------------------------------------------------------------------------- +def next_wealth(wealth, consumption, working, wage, interest_rate): + return (1 + interest_rate) * (wealth - consumption) + wage * working + + +def next_health(health, sport, working): + return health * (1 + sport - working / 2) + + +def next_wealth_with_shock( + wealth, + consumption, + working, + wage, + wage_shock, + interest_rate, +): + return interest_rate * (wealth - consumption) + wage * wage_shock * working + + +# -------------------------------------------------------------------------------------- +# Constraints +# -------------------------------------------------------------------------------------- +def consumption_constraint(consumption, wealth): + return consumption <= wealth + + +# ====================================================================================== +# Model specification and parameters +# ====================================================================================== + +MODEL_CONFIG = { + "functions": { + "utility": utility, + "next_wealth": next_wealth, + "consumption_constraint": consumption_constraint, + "working": working, + "wage": wage, + "age": age, + "next_health": next_health, + }, + "choices": { + "leisure": {"options": [0, 1]}, + "consumption": { + "grid_type": "linspace", + "start": 1, + "stop": 100, + "n_points": N_GRID_POINTS["choices"], + }, + "sport": { + "grid_type": "linspace", + "start": 0, + "stop": 1, + "n_points": N_GRID_POINTS["choices"], + }, + }, + "states": { + "wealth": { + "grid_type": "linspace", + "start": 1, + "stop": 100, + "n_points": N_GRID_POINTS["states"], + }, + "health": { + "grid_type": "linspace", + "start": 0, + "stop": 1, + "n_points": N_GRID_POINTS["states"], + }, + }, + "n_periods": RETIREMENT_AGE - 18, +} + +PARAMS = { + "beta": 0.95, + "utility": {"delta": 0.05}, + "next_wealth": {"interest_rate": 0.05}, +} diff --git a/pyproject.toml b/pyproject.toml index bfba52a5..7bcb07db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,8 @@ write_to = "src/lcm/_version.py" [tool.ruff] target-version = "py311" fix = true + +[tool.ruff.lint] select = ["ALL"] extend-ignore = [ # missing type annotation @@ -57,6 +59,7 @@ extend-ignore = [ [tool.ruff.lint.per-file-ignores] "docs/source/conf.py" = ["E501", "ERA001", "DTZ005"] "tests/test_*.py" = ["PLR2004"] +"examples/*" = ["INP001"] [tool.ruff.lint.pydocstyle] convention = "google" @@ -75,7 +78,7 @@ markers = [ "slow: Tests that take a long time to run and are skipped in continuous integration.", "illustrative: Tests are designed for illustrative purposes", ] -norecursedirs = ["docs", ".envs"] +norecursedirs = ["docs", ".envs", "tests/test_models"] [tool.yamlfix] diff --git a/src/lcm/example_models/testing_example_models.py b/src/lcm/example_models/testing_example_models.py deleted file mode 100644 index 097207bf..00000000 --- a/src/lcm/example_models/testing_example_models.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Define example model specifications.""" - -import jax.numpy as jnp - -RETIREMENT_AGE = 65 -N_CHOICE_GRID_POINTS = 200 -N_STATE_GRID_POINTS = 100 - - -def phelps_deaton_utility(consumption, working, health, sport, delta): - return jnp.log(consumption) - (delta - health) * working - sport - - -def working(retirement): - return 1 - retirement - - -def next_wealth_with_shock( - wealth, - consumption, - working, - wage, - wage_shock, - interest_rate, -): - return interest_rate * (wealth - consumption) + wage * wage_shock * working - - -def next_wealth(wealth, consumption, working, wage, interest_rate): - return (1 + interest_rate) * (wealth - consumption) + wage * working - - -def next_health(health, sport, working): - return health * (1 + sport - working / 2) - - -def consumption_constraint(consumption, wealth): - return consumption <= wealth - - -def wage(age): - return 1 + 0.1 * age - - -def age(_period): - return _period + 18 - - -PHELPS_DEATON = { - "functions": { - "utility": phelps_deaton_utility, - "next_wealth": next_wealth, - "consumption_constraint": consumption_constraint, - "working": working, - "wage": wage, - "age": age, - "next_health": next_health, - }, - "choices": { - "retirement": {"options": [0, 1]}, - "consumption": { - "grid_type": "linspace", - "start": 1, - "stop": 100, - "n_points": N_CHOICE_GRID_POINTS, - }, - "sport": { - "grid_type": "linspace", - "start": 0, - "stop": 1, - "n_points": N_CHOICE_GRID_POINTS, - }, - }, - "states": { - "wealth": { - "grid_type": "linspace", - "start": 1, - "stop": 100, - "n_points": N_STATE_GRID_POINTS, - }, - "health": { - "grid_type": "linspace", - "start": 0, - "stop": 1, - "n_points": N_STATE_GRID_POINTS, - }, - }, - "n_periods": RETIREMENT_AGE - 18, -} - -PARAMS = { - "beta": 0.95, - "utility": {"delta": 0.05}, - "next_wealth": {"interest_rate": 0.05}, -} diff --git a/src/lcm/get_model.py b/src/lcm/get_model.py index 4e0f37e4..774115f3 100644 --- a/src/lcm/get_model.py +++ b/src/lcm/get_model.py @@ -4,7 +4,7 @@ from pybaum import tree_update -from lcm.example_models.basic_example_models import ( +from tests.test_models.phelps_deaton import ( PHELPS_DEATON, PHELPS_DEATON_WITH_FILTERS, ) diff --git a/src/lcm/test.py b/src/lcm/test.py deleted file mode 100644 index dceb1619..00000000 --- a/src/lcm/test.py +++ /dev/null @@ -1,8 +0,0 @@ -from lcm.logger import get_logger - -if __name__ == "__main__": - logger = get_logger(debug_mode=False) - - logger.info("This is an info message.") - logger.debug("This is a debug message.") - logger.warning("This is a warning message.") diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 64fa6e49..05a9b833 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -5,17 +5,18 @@ create_compute_conditional_continuation_value, get_lcm_function, ) -from lcm.example_models.basic_example_models import ( - PHELPS_DEATON, - PHELPS_DEATON_FULLY_DISCRETE, - PHELPS_DEATON_WITH_FILTERS, - phelps_deaton_utility, -) from lcm.model_functions import get_utility_and_feasibility_function from lcm.process_model import process_model from lcm.state_space import create_state_choice_space from pybaum import tree_equal, tree_map +from tests.test_models.phelps_deaton import ( + PHELPS_DEATON, + PHELPS_DEATON_FULLY_DISCRETE, + PHELPS_DEATON_WITH_FILTERS, + utility, +) + MODELS = { "simple": PHELPS_DEATON, "with_filters": PHELPS_DEATON_WITH_FILTERS, @@ -175,7 +176,7 @@ def test_create_compute_conditional_continuation_value(): params=params, vf_arr=None, ) - assert val == phelps_deaton_utility(consumption=30.0, working=0, delta=1.0) + assert val == utility(consumption=30.0, working=0, delta=1.0) def test_create_compute_conditional_continuation_value_with_discrete_model(): @@ -218,7 +219,7 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): params=params, vf_arr=None, ) - assert val == phelps_deaton_utility(consumption=2, working=0, delta=1.0) + assert val == utility(consumption=2, working=0, delta=1.0) # ====================================================================================== @@ -267,7 +268,7 @@ def test_create_compute_conditional_continuation_policy(): vf_arr=None, ) assert policy == 2 - assert val == phelps_deaton_utility(consumption=30.0, working=0, delta=1.0) + assert val == utility(consumption=30.0, working=0, delta=1.0) def test_create_compute_conditional_continuation_policy_with_discrete_model(): @@ -311,4 +312,4 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): vf_arr=None, ) assert policy == 1 - assert val == phelps_deaton_utility(consumption=2, working=0, delta=1.0) + assert val == utility(consumption=2, working=0, delta=1.0) diff --git a/tests/test_long.py b/tests/test_long.py deleted file mode 100644 index b7dacabc..00000000 --- a/tests/test_long.py +++ /dev/null @@ -1,13 +0,0 @@ -import pytest -from lcm.entry_point import get_lcm_function -from lcm.example_models.testing_example_models import PARAMS, PHELPS_DEATON - -SKIP_REASON = """The test is designed to run approximately 1 minute on a standard -laptop, such that we can differentiate the performance of running LCM on a GPU versus -on the CPU.""" - - -@pytest.mark.skip(reason=SKIP_REASON) -def test_long(): - solve_model, template = get_lcm_function(PHELPS_DEATON, targets="solve") - solve_model(PARAMS) diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index 197ee350..f2a0a7c4 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -1,6 +1,5 @@ import jax.numpy as jnp import pandas as pd -from lcm.example_models.basic_example_models import PHELPS_DEATON, phelps_deaton_utility from lcm.interfaces import Model from lcm.model_functions import ( get_combined_constraint, @@ -11,6 +10,8 @@ from lcm.state_space import create_state_choice_space from numpy.testing import assert_array_equal +from tests.test_models.phelps_deaton import PHELPS_DEATON, utility + def test_get_combined_constraint(): def f(): @@ -82,7 +83,7 @@ def test_get_utility_and_feasibility_function(): assert_array_equal( u, - phelps_deaton_utility( + utility( consumption=consumption, working=1 - retirement, delta=1.0, diff --git a/src/lcm/example_models/__init__.py b/tests/test_models/__init__.py similarity index 100% rename from src/lcm/example_models/__init__.py rename to tests/test_models/__init__.py diff --git a/src/lcm/example_models/basic_example_models.py b/tests/test_models/phelps_deaton.py similarity index 53% rename from src/lcm/example_models/basic_example_models.py rename to tests/test_models/phelps_deaton.py index 8ae9a442..06184a10 100644 --- a/src/lcm/example_models/basic_example_models.py +++ b/tests/test_models/phelps_deaton.py @@ -1,65 +1,71 @@ -"""Define example model specifications.""" +"""Example specifications of the Phelps-Deaton model.""" import jax.numpy as jnp -RETIREMENT_AGE = 65 -N_CHOICE_GRID_POINTS = 500 -N_STATE_GRID_POINTS = 100 +# ====================================================================================== +# Numerical parameters and constants +# ====================================================================================== +N_GRID_POINTS = { + "states": 100, + "choices": 200, +} -def phelps_deaton_utility_with_shock( - consumption, - working, - delta, - additive_utility_shock, -): - return jnp.log(consumption) + additive_utility_shock - delta * working +RETIREMENT_AGE = 65 +# ====================================================================================== +# Model functions +# ====================================================================================== -def phelps_deaton_utility(consumption, working, delta): + +# -------------------------------------------------------------------------------------- +# Utility functions +# -------------------------------------------------------------------------------------- +def utility(consumption, working, delta): return jnp.log(consumption) - delta * working -def phelps_deaton_utility_with_filter( +def utility_with_filter( consumption, working, delta, lagged_retirement, # noqa: ARG001 ): - return jnp.log(consumption) - delta * working + return utility(consumption=consumption, working=working, delta=delta) +# -------------------------------------------------------------------------------------- +# Auxiliary variables +# -------------------------------------------------------------------------------------- def working(retirement): return 1 - retirement -def next_wealth_with_shock( - wealth, - consumption, - working, - wage, - wage_shock, - interest_rate, -): - return interest_rate * (wealth - consumption) + wage * wage_shock * working +def wage(age): + return 1 + 0.1 * age + + +def age(_period): + return _period + 18 +# -------------------------------------------------------------------------------------- +# State transitions +# -------------------------------------------------------------------------------------- def next_wealth(wealth, consumption, working, wage, interest_rate): return (1 + interest_rate) * (wealth - consumption) + wage * working +# -------------------------------------------------------------------------------------- +# Constraints +# -------------------------------------------------------------------------------------- def consumption_constraint(consumption, wealth): return consumption <= wealth -def wage(age): - return 1 + 0.1 * age - - -def age(_period): - return _period + 18 - - +# -------------------------------------------------------------------------------------- +# Filters +# -------------------------------------------------------------------------------------- def mandatory_retirement_filter(retirement, age): return jnp.logical_or(retirement == 1, age < RETIREMENT_AGE) @@ -68,9 +74,13 @@ def absorbing_retirement_filter(retirement, lagged_retirement): return jnp.logical_or(retirement == 1, lagged_retirement == 0) +# ====================================================================================== +# Model specification and parameters +# ====================================================================================== + PHELPS_DEATON = { "functions": { - "utility": phelps_deaton_utility, + "utility": utility, "next_wealth": next_wealth, "consumption_constraint": consumption_constraint, "working": working, @@ -83,7 +93,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 0, "stop": 100, - "n_points": N_CHOICE_GRID_POINTS, + "n_points": N_GRID_POINTS["choices"], }, }, "states": { @@ -91,7 +101,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 0, "stop": 100, - "n_points": N_STATE_GRID_POINTS, + "n_points": N_GRID_POINTS["states"], }, }, "n_periods": 3, @@ -100,7 +110,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): PHELPS_DEATON_FULLY_DISCRETE = { "functions": { - "utility": phelps_deaton_utility, + "utility": utility, "next_wealth": next_wealth, "consumption_constraint": consumption_constraint, "working": working, @@ -114,33 +124,16 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 0, "stop": 100, - "n_points": N_STATE_GRID_POINTS, + "n_points": N_GRID_POINTS["states"], }, }, "n_periods": 3, } -PHELPS_DEATON_WITH_SHOCKS = { - **PHELPS_DEATON, - "functions": { - "utility": phelps_deaton_utility_with_shock, - "next_wealth": next_wealth_with_shock, - "consumption_constraint": consumption_constraint, - "working": working, - }, - "shocks": { - "wage_shock": "lognormal", - # special name to signal that this shock can be set to zero to calculate - # expected utility - "additive_utility_shock": "extreme_value", - }, -} - - PHELPS_DEATON_WITH_FILTERS = { "functions": { - "utility": phelps_deaton_utility_with_filter, + "utility": utility_with_filter, "next_wealth": next_wealth, "consumption_constraint": consumption_constraint, "working": working, @@ -153,7 +146,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 1, "stop": 100, - "n_points": N_CHOICE_GRID_POINTS, + "n_points": N_GRID_POINTS["choices"], }, }, "states": { @@ -161,7 +154,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): "grid_type": "linspace", "start": 0, "stop": 100, - "n_points": N_STATE_GRID_POINTS, + "n_points": N_GRID_POINTS["states"], }, "lagged_retirement": {"options": [0, 1]}, }, diff --git a/src/lcm/example_models/stochastic_example_models.py b/tests/test_models/stochastic.py similarity index 67% rename from src/lcm/example_models/stochastic_example_models.py rename to tests/test_models/stochastic.py index fb139424..0e8cb3ad 100644 --- a/src/lcm/example_models/stochastic_example_models.py +++ b/tests/test_models/stochastic.py @@ -1,21 +1,39 @@ -"""Define example model specifications.""" +"""Example specifications of a simple Phelps-Deaton style stochastic model.""" import jax.numpy as jnp - import lcm -N_CHOICE_GRID_POINTS = 500 -N_STATE_GRID_POINTS = 100 +# ====================================================================================== +# Numerical parameters and constants +# ====================================================================================== + +N_GRID_POINTS = { + "states": 100, + "choices": 200, +} +# ====================================================================================== +# Model functions +# ====================================================================================== + +# -------------------------------------------------------------------------------------- +# Utility function +# -------------------------------------------------------------------------------------- def utility(consumption, working, health, partner, delta, gamma): # noqa: ARG001 return jnp.log(consumption) + (gamma * health - delta) * working +# -------------------------------------------------------------------------------------- +# Deterministic state transitions +# -------------------------------------------------------------------------------------- def next_wealth(wealth, consumption, working, wage, interest_rate): return (1 + interest_rate) * (wealth - consumption) + wage * working +# -------------------------------------------------------------------------------------- +# Stochastic state transitions +# -------------------------------------------------------------------------------------- @lcm.mark.stochastic def next_health(health, partner): # noqa: ARG001 pass @@ -26,11 +44,18 @@ def next_partner(_period, working, partner): # noqa: ARG001 pass +# -------------------------------------------------------------------------------------- +# Constraints +# -------------------------------------------------------------------------------------- def consumption_constraint(consumption, wealth): return consumption <= wealth -MODEL = { +# ====================================================================================== +# Model specification and parameters +# ====================================================================================== + +MODEL_CONFIG = { "functions": { "utility": utility, "next_wealth": next_wealth, @@ -44,7 +69,7 @@ def consumption_constraint(consumption, wealth): "grid_type": "linspace", "start": 1, "stop": 100, - "n_points": N_CHOICE_GRID_POINTS, + "n_points": N_GRID_POINTS["choices"], }, }, "states": { @@ -54,7 +79,7 @@ def consumption_constraint(consumption, wealth): "grid_type": "linspace", "start": 1, "stop": 100, - "n_points": N_STATE_GRID_POINTS, + "n_points": N_GRID_POINTS["states"], }, }, "n_periods": 3, diff --git a/tests/test_next_state.py b/tests/test_next_state.py index 6ff5c5fa..572ad993 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -1,11 +1,12 @@ import jax.numpy as jnp import pandas as pd -from lcm.example_models.basic_example_models import PHELPS_DEATON from lcm.interfaces import Model from lcm.next_state import _get_stochastic_next_func, get_next_state_function from lcm.process_model import process_model from pybaum import tree_equal +from tests.test_models.phelps_deaton import PHELPS_DEATON + # ====================================================================================== # Solve target # ====================================================================================== diff --git a/tests/test_process_model.py b/tests/test_process_model.py index 846f3818..e5ecb684 100644 --- a/tests/test_process_model.py +++ b/tests/test_process_model.py @@ -4,12 +4,6 @@ import numpy as np import pandas as pd import pytest -from lcm.example_models.basic_example_models import ( - N_CHOICE_GRID_POINTS, - N_STATE_GRID_POINTS, - PHELPS_DEATON, - PHELPS_DEATON_WITH_FILTERS, -) from lcm.interfaces import GridSpec from lcm.mark import StochasticInfo from lcm.process_model import ( @@ -23,6 +17,13 @@ from numpy.testing import assert_array_equal from pandas.testing import assert_frame_equal +from tests.test_models.phelps_deaton import ( + N_CHOICE_GRID_POINTS, + N_STATE_GRID_POINTS, + PHELPS_DEATON, + PHELPS_DEATON_WITH_FILTERS, +) + @pytest.fixture() def user_model(): diff --git a/tests/test_simulate.py b/tests/test_simulate.py index bbfb772f..e3c0287f 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -6,11 +6,6 @@ create_compute_conditional_continuation_policy, get_lcm_function, ) -from lcm.example_models.basic_example_models import ( - N_CHOICE_GRID_POINTS, - PHELPS_DEATON, - PHELPS_DEATON_WITH_FILTERS, -) from lcm.logging import get_logger from lcm.model_functions import get_utility_and_feasibility_function from lcm.next_state import _get_next_state_function_simulation @@ -32,6 +27,12 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal from pybaum import tree_equal +from tests.test_models.phelps_deaton import ( + N_CHOICE_GRID_POINTS, + PHELPS_DEATON, + PHELPS_DEATON_WITH_FILTERS, +) + # ====================================================================================== # Test simulate using raw inputs # ====================================================================================== diff --git a/tests/test_state_space.py b/tests/test_state_space.py index 8f9bc33c..a5a5b8e0 100644 --- a/tests/test_state_space.py +++ b/tests/test_state_space.py @@ -2,7 +2,6 @@ import numpy as np import pandas as pd import pytest -from lcm.example_models.basic_example_models import PHELPS_DEATON_WITH_FILTERS from lcm.interfaces import Model from lcm.process_model import process_model from lcm.state_space import ( @@ -14,6 +13,8 @@ ) from numpy.testing import assert_array_almost_equal as aaae +from tests.test_models.phelps_deaton import PHELPS_DEATON_WITH_FILTERS + def test_create_state_choice_space(): _model = process_model(PHELPS_DEATON_WITH_FILTERS) diff --git a/tests/test_stochastic.py b/tests/test_stochastic.py index 2a0223cc..636247a2 100644 --- a/tests/test_stochastic.py +++ b/tests/test_stochastic.py @@ -5,7 +5,8 @@ from lcm.entry_point import ( get_lcm_function, ) -from lcm.example_models.stochastic_example_models import MODEL, PARAMS + +from tests.test_models.stochastic import MODEL_CONFIG, PARAMS # ====================================================================================== # Simulate @@ -13,7 +14,10 @@ def test_get_lcm_function_with_simulate_target(): - simulate_model, _ = get_lcm_function(model=MODEL, targets="solve_and_simulate") + simulate_model, _ = get_lcm_function( + model=MODEL_CONFIG, + targets="solve_and_simulate", + ) res = simulate_model( PARAMS, @@ -47,7 +51,7 @@ def test_get_lcm_function_with_simulate_target(): def test_get_lcm_function_with_solve_target(): - solve_model, _ = get_lcm_function(model=MODEL) + solve_model, _ = get_lcm_function(model=MODEL_CONFIG) solve_model(PARAMS)