diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index da98e44..fe03e22 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -29,7 +29,7 @@ def get_lcm_function( targets: Literal["solve", "simulate", "solve_and_simulate"] = "solve", *, debug_mode: bool = True, - jit: bool = False, + jit: bool = True, ) -> tuple[Callable, ParamsDict]: """Entry point for users to get high level functions generated by lcm. diff --git a/tests/test_regression_test.py b/tests/test_regression_test.py index 52c360a..3f4ab3e 100644 --- a/tests/test_regression_test.py +++ b/tests/test_regression_test.py @@ -48,4 +48,4 @@ def test_regression_test(): # Compare # ================================================================================== aaae(expected_solve, got_solve, decimal=5) - assert_frame_equal(expected_simulate, got_simulate) + assert_frame_equal(expected_simulate, got_simulate, check_like=True) diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 859aa99..6d0c1b1 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -94,8 +94,8 @@ def test_simulate_using_raw_inputs(simulate_inputs): **simulate_inputs, ) - assert_array_equal(got.loc[0, :]["retirement"], 1) - assert_array_almost_equal(got.loc[0, :]["consumption"], jnp.array([1.0, 50.400803])) + assert_array_equal(got["retirement"], 1) + assert_array_almost_equal(got["consumption"], jnp.array([1.0, 50.400803])) # ====================================================================================== @@ -336,6 +336,7 @@ def f_b(b, params): # noqa: ARG001 params={"disutility_of_work": -1.0}, ) expected = { + **processed_results, "fa": jnp.arange(3) - 1.0, "fb": 1 + jnp.arange(3), }