diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index fe03e22..37ec5df 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -1,7 +1,7 @@ import functools +import inspect from collections.abc import Callable from functools import partial -import inspect from typing import Literal, cast import jax @@ -17,7 +17,7 @@ get_utility_and_feasibility_function, ) from lcm.next_state import get_next_state_function -from lcm.simulate import simulate, _as_data_frame, _compute_targets +from lcm.simulate import _as_data_frame, _compute_targets, simulate from lcm.solve_brute import solve from lcm.state_space import create_state_choice_space from lcm.typing import ParamsDict @@ -184,23 +184,41 @@ def get_lcm_function( simulate_model = jax.jit(_simulate_model) if jit else _simulate_model if targets == "solve": - def _target(*args,**kwargs): - return solve_model(*args,**kwargs) + + def _target(*args, **kwargs): + return solve_model(*args, **kwargs) elif targets == "simulate": - def _target(*args,**kwargs): - kwargs = all_as_kwargs(args,kwargs,list(inspect.signature(simulate).parameters,)) - additional_targets = kwargs.get('additional_targets') - kwargs.pop('additional_targets',None) + + def _target(*args, **kwargs): + kwargs = all_as_kwargs( + args, kwargs, list(inspect.signature(simulate).parameters) + ) + additional_targets = kwargs.get("additional_targets") + kwargs.pop("additional_targets", None) _simulated = simulate_model(**kwargs) - return _as_data_frame(_compute_targets(_simulated, additional_targets, _mod.functions, kwargs['params']), _mod.n_periods) + return _as_data_frame( + _compute_targets( + _simulated, additional_targets, _mod.functions, kwargs["params"] + ), + _mod.n_periods, + ) elif targets == "solve_and_simulate": - def _target(*args,**kwargs): - kwargs = all_as_kwargs(args,kwargs,list(inspect.signature(simulate).parameters,)) - additional_targets = kwargs.get('additional_targets') - kwargs.pop('additional_targets',None) - _solved = solve_model(kwargs['params']) + + def _target(*args, **kwargs): + kwargs = all_as_kwargs( + args, kwargs, list(inspect.signature(simulate).parameters) + ) + additional_targets = kwargs.get("additional_targets") + kwargs.pop("additional_targets", None) + _solved = solve_model(kwargs["params"]) _simulated = simulate_model(**kwargs, vf_arr_list=_solved) - return _as_data_frame(_compute_targets(_simulated, additional_targets, _mod.functions, kwargs['params']), _mod.n_periods) + return _as_data_frame( + _compute_targets( + _simulated, additional_targets, _mod.functions, kwargs["params"] + ), + _mod.n_periods, + ) + return cast(Callable, _target), _mod.params diff --git a/src/lcm/simulate.py b/src/lcm/simulate.py index 468542e..cc2283e 100644 --- a/src/lcm/simulate.py +++ b/src/lcm/simulate.py @@ -201,9 +201,7 @@ def simulate( logger.info("Period: %s", period) - processed = _process_simulated_data(_simulation_results) - - return processed + return _process_simulated_data(_simulation_results) def solve_continuous_problem( @@ -316,10 +314,9 @@ def _compute_targets(processed_results, targets, model_functions, params): target_func = vmap_1d(target_func, variables=variables) kwargs = {k: v for k, v in processed_results.items() if k in variables} - + return {**processed_results, **target_func(params=params, **kwargs)} - else: - return processed_results + return processed_results def _process_simulated_data(results):