Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Feb 5, 2024
1 parent f253c01 commit fb133e0
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
4 changes: 1 addition & 3 deletions src/lcm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import jax

from lcm import mark

jax.config.update("jax_platform_name", "cpu")
# jax.config.update("jax_platform_name", "cpu")


__all__ = ["mark"]
40 changes: 40 additions & 0 deletions test_long.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Testing against the analytical solution by Iskhakov et al (2017)."""
import jax
import numpy as np
from lcm.entry_point import get_lcm_function
from lcm.get_model import get_model

TEST_CASES = {
"iskhakov_2017_five_periods": get_model("iskhakov_2017_five_periods"),
"iskhakov_2017_low_delta": get_model("iskhakov_2017_low_delta"),
}


def mean_square_error(x, y, axis=None):
return np.mean((x - y) ** 2, axis=axis)


def test_analytical_solution(model_name, model_config):
"""Test that the numerical solution matches the analytical solution.
The analytical solution is from Iskhakov et al (2017) and is generated
in the development repository: github.com/opensourceeconomics/lcm-dev.
"""
# Compute LCM solution
# ==================================================================================
solve_model, _ = get_lcm_function(model=model_config.model)

vf_arr_list = solve_model(params=model_config.params)
_numerical = np.stack(vf_arr_list)
numerical = {
"worker": _numerical[:, 0, :],
"retired": _numerical[:, 1, :],
}


with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
test_analytical_solution(
"iskhakov_2017_five_periods",
get_model("iskhakov_2017_five_periods"),
)

0 comments on commit fb133e0

Please sign in to comment.