From 0055821b4f37201be2fd4b3561341fc43766b77b Mon Sep 17 00:00:00 2001 From: mj023 Date: Mon, 21 Oct 2024 22:49:25 +0200 Subject: [PATCH] Add Jit to simulation --- src/lcm/entry_point.py | 11 ++++++----- src/lcm/simulate.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 1e560fb..e049af6 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -1,5 +1,6 @@ import functools from collections.abc import Callable +from dags import concatenate_functions from functools import partial from typing import Literal, cast @@ -15,7 +16,7 @@ get_utility_and_feasibility_function, ) from lcm.next_state import get_next_state_function -from lcm.simulate import simulate +from lcm.simulate import simulate, _as_data_frame from lcm.solve_brute import solve from lcm.state_space import create_state_choice_space from lcm.typing import ParamsDict @@ -170,7 +171,7 @@ def get_lcm_function( solve_model = jax.jit(_solve_model) if jit else _solve_model _next_state_simulate = get_next_state_function(model=_mod, target="simulate") - simulate_model = partial( + _simulate_model = partial( simulate, state_indexers=state_indexers, continuous_choice_grids=continuous_choice_grids, @@ -179,14 +180,14 @@ def get_lcm_function( next_state=jax.jit(_next_state_simulate), logger=logger, ) + simulate_model = jax.jit(_simulate_model, static_argnames="solve_model") if jit else _simulate_model if targets == "solve": _target = solve_model elif targets == "simulate": - _target = simulate_model + _target = lambda *args,**kwargs: _as_data_frame(simulate_model(*args,**kwargs),_mod.n_periods) elif targets == "solve_and_simulate": - _target = partial(simulate_model, solve_model=solve_model) - + _target = lambda *args,**kwargs: _as_data_frame(partial(simulate_model, solve_model=solve_model)(*args,**kwargs),_mod.n_periods) return cast(Callable, _target), _mod.params diff --git a/src/lcm/simulate.py b/src/lcm/simulate.py index 561cfce..686c98c 100644 --- a/src/lcm/simulate.py +++ b/src/lcm/simulate.py @@ -218,7 +218,7 @@ def simulate( ) processed = {**processed, **calculated_targets} - return _as_data_frame(processed, n_periods=n_periods) + return processed def solve_continuous_problem(