From ff093d99eab14896ef5f920fea10562e299b99e9 Mon Sep 17 00:00:00 2001 From: mj023 Date: Thu, 31 Oct 2024 03:59:13 +0100 Subject: [PATCH] Move Compute Targets --- src/lcm/entry_point.py | 27 +++++++++++++++++------ src/lcm/simulate.py | 49 ++++++++++++++++-------------------------- 2 files changed, 39 insertions(+), 37 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index e049af6..da98e44 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -1,7 +1,7 @@ import functools from collections.abc import Callable -from dags import concatenate_functions from functools import partial +import inspect from typing import Literal, cast import jax @@ -10,13 +10,14 @@ from lcm.argmax import argmax from lcm.discrete_problem import get_solve_discrete_problem from lcm.dispatchers import productmap +from lcm.functools import all_as_kwargs from lcm.input_processing import process_model from lcm.logging import get_logger from lcm.model_functions import ( get_utility_and_feasibility_function, ) from lcm.next_state import get_next_state_function -from lcm.simulate import simulate, _as_data_frame +from lcm.simulate import simulate, _as_data_frame, _compute_targets from lcm.solve_brute import solve from lcm.state_space import create_state_choice_space from lcm.typing import ParamsDict @@ -28,7 +29,7 @@ def get_lcm_function( targets: Literal["solve", "simulate", "solve_and_simulate"] = "solve", *, debug_mode: bool = True, - jit: bool = True, + jit: bool = False, ) -> tuple[Callable, ParamsDict]: """Entry point for users to get high level functions generated by lcm. @@ -180,14 +181,26 @@ def get_lcm_function( next_state=jax.jit(_next_state_simulate), logger=logger, ) - simulate_model = jax.jit(_simulate_model, static_argnames="solve_model") if jit else _simulate_model + simulate_model = jax.jit(_simulate_model) if jit else _simulate_model if targets == "solve": - _target = solve_model + def _target(*args,**kwargs): + return solve_model(*args,**kwargs) elif targets == "simulate": - _target = lambda *args,**kwargs: _as_data_frame(simulate_model(*args,**kwargs),_mod.n_periods) + def _target(*args,**kwargs): + kwargs = all_as_kwargs(args,kwargs,list(inspect.signature(simulate).parameters,)) + additional_targets = kwargs.get('additional_targets') + kwargs.pop('additional_targets',None) + _simulated = simulate_model(**kwargs) + return _as_data_frame(_compute_targets(_simulated, additional_targets, _mod.functions, kwargs['params']), _mod.n_periods) elif targets == "solve_and_simulate": - _target = lambda *args,**kwargs: _as_data_frame(partial(simulate_model, solve_model=solve_model)(*args,**kwargs),_mod.n_periods) + def _target(*args,**kwargs): + kwargs = all_as_kwargs(args,kwargs,list(inspect.signature(simulate).parameters,)) + additional_targets = kwargs.get('additional_targets') + kwargs.pop('additional_targets',None) + _solved = solve_model(kwargs['params']) + _simulated = simulate_model(**kwargs, vf_arr_list=_solved) + return _as_data_frame(_compute_targets(_simulated, additional_targets, _mod.functions, kwargs['params']), _mod.n_periods) return cast(Callable, _target), _mod.params diff --git a/src/lcm/simulate.py b/src/lcm/simulate.py index 686c98c..468542e 100644 --- a/src/lcm/simulate.py +++ b/src/lcm/simulate.py @@ -21,9 +21,7 @@ def simulate( model: InternalModel, next_state, logger, - solve_model=None, vf_arr_list=None, - additional_targets=None, seed=12345, ): """Simulate the model forward in time. @@ -59,13 +57,9 @@ def simulate( """ if vf_arr_list is None: - if solve_model is None: - raise ValueError( - "You need to provide either vf_arr_list or solve_model.", - ) - # We do not need to convert the params here, because the solve_model function - # will do it. - vf_arr_list = solve_model(params) + raise ValueError( + "You need to provide either vf_arr_list or solve_model.", + ) logger.info("Starting simulation") @@ -209,15 +203,6 @@ def simulate( processed = _process_simulated_data(_simulation_results) - if additional_targets is not None: - calculated_targets = _compute_targets( - processed, - targets=additional_targets, - model_functions=model.functions, - params=params, - ) - processed = {**processed, **calculated_targets} - return processed @@ -316,21 +301,25 @@ def _compute_targets(processed_results, targets, model_functions, params): dict: Dict with computed targets. """ - target_func = concatenate_functions( - functions=model_functions, - targets=targets, - return_type="dict", - ) + if targets is not None: + target_func = concatenate_functions( + functions=model_functions, + targets=targets, + return_type="dict", + ) - # get list of variables over which we want to vectorize the target function - variables = [ - p for p in list(inspect.signature(target_func).parameters) if p != "params" - ] + # get list of variables over which we want to vectorize the target function + variables = [ + p for p in list(inspect.signature(target_func).parameters) if p != "params" + ] - target_func = vmap_1d(target_func, variables=variables) + target_func = vmap_1d(target_func, variables=variables) - kwargs = {k: v for k, v in processed_results.items() if k in variables} - return target_func(params=params, **kwargs) + kwargs = {k: v for k, v in processed_results.items() if k in variables} + + return {**processed_results, **target_func(params=params, **kwargs)} + else: + return processed_results def _process_simulated_data(results):