Skip to content

Commit

Permalink
Fix Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
mj023 committed Nov 6, 2024
1 parent 8f419d3 commit ec123cd
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 21 deletions.
48 changes: 33 additions & 15 deletions src/lcm/entry_point.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import inspect
from collections.abc import Callable
from functools import partial
import inspect
from typing import Literal, cast

import jax
Expand All @@ -17,7 +17,7 @@
get_utility_and_feasibility_function,
)
from lcm.next_state import get_next_state_function
from lcm.simulate import simulate, _as_data_frame, _compute_targets
from lcm.simulate import _as_data_frame, _compute_targets, simulate
from lcm.solve_brute import solve
from lcm.state_space import create_state_choice_space
from lcm.typing import ParamsDict
Expand Down Expand Up @@ -184,23 +184,41 @@ def get_lcm_function(
simulate_model = jax.jit(_simulate_model) if jit else _simulate_model

if targets == "solve":
def _target(*args,**kwargs):
return solve_model(*args,**kwargs)

def _target(*args, **kwargs):
return solve_model(*args, **kwargs)
elif targets == "simulate":
def _target(*args,**kwargs):
kwargs = all_as_kwargs(args,kwargs,list(inspect.signature(simulate).parameters,))
additional_targets = kwargs.get('additional_targets')
kwargs.pop('additional_targets',None)

def _target(*args, **kwargs):
kwargs = all_as_kwargs(
args, kwargs, list(inspect.signature(simulate).parameters)
)
additional_targets = kwargs.get("additional_targets")
kwargs.pop("additional_targets", None)
_simulated = simulate_model(**kwargs)
return _as_data_frame(_compute_targets(_simulated, additional_targets, _mod.functions, kwargs['params']), _mod.n_periods)
return _as_data_frame(
_compute_targets(
_simulated, additional_targets, _mod.functions, kwargs["params"]
),
_mod.n_periods,
)
elif targets == "solve_and_simulate":
def _target(*args,**kwargs):
kwargs = all_as_kwargs(args,kwargs,list(inspect.signature(simulate).parameters,))
additional_targets = kwargs.get('additional_targets')
kwargs.pop('additional_targets',None)
_solved = solve_model(kwargs['params'])

def _target(*args, **kwargs):
kwargs = all_as_kwargs(
args, kwargs, list(inspect.signature(simulate).parameters)
)
additional_targets = kwargs.get("additional_targets")
kwargs.pop("additional_targets", None)
_solved = solve_model(kwargs["params"])
_simulated = simulate_model(**kwargs, vf_arr_list=_solved)
return _as_data_frame(_compute_targets(_simulated, additional_targets, _mod.functions, kwargs['params']), _mod.n_periods)
return _as_data_frame(
_compute_targets(
_simulated, additional_targets, _mod.functions, kwargs["params"]
),
_mod.n_periods,
)

return cast(Callable, _target), _mod.params


Expand Down
9 changes: 3 additions & 6 deletions src/lcm/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,7 @@ def simulate(

logger.info("Period: %s", period)

processed = _process_simulated_data(_simulation_results)

return processed
return _process_simulated_data(_simulation_results)


def solve_continuous_problem(
Expand Down Expand Up @@ -316,10 +314,9 @@ def _compute_targets(processed_results, targets, model_functions, params):
target_func = vmap_1d(target_func, variables=variables)

kwargs = {k: v for k, v in processed_results.items() if k in variables}

return {**processed_results, **target_func(params=params, **kwargs)}
else:
return processed_results
return processed_results


def _process_simulated_data(results):
Expand Down

0 comments on commit ec123cd

Please sign in to comment.