diff --git a/src/lcm/input_processing/discrete_state_conversion.py b/src/lcm/input_processing/discrete_state_conversion.py index 5d3216f..1ac6437 100644 --- a/src/lcm/input_processing/discrete_state_conversion.py +++ b/src/lcm/input_processing/discrete_state_conversion.py @@ -1,4 +1,3 @@ -import inspect from collections.abc import Callable from dataclasses import dataclass, field @@ -100,8 +99,8 @@ def convert_discrete_options_to_indices( 1. Remove the variable from the states or choices dictionary 2. Replace it with a new state or choice with index options (__{var}_index__) - 3. Add a function that maps the index options to the original options - 4. Add updated next functions (if the variable was a state variable) + 3. Add updated next functions (if the variable was a state variable) + 4. Add a function that maps the index options to the original options Args: model: The model as provided by the user. @@ -116,43 +115,34 @@ def convert_discrete_options_to_indices( non_index_discrete_vars = _get_discrete_vars_with_non_index_options(model) + # fast path if not non_index_discrete_vars: - # fast path return model, Converter() functions = model.functions.copy() states = model.states.copy() choices = model.choices.copy() - # Update next functions (needs to be done before updating the grids, otherwise the - # already updated state variables are being used) - # ---------------------------------------------------------------------------------- - non_index_states = [s for s in states if s in non_index_discrete_vars] - - for state in model.states: - next_func = model.functions[f"next_{state}"] - must_be_updated = _func_depends_on(next_func, depends_on=non_index_states) - if must_be_updated: - functions.pop(f"next_{state}") - - functions[f"next___{state}_index__"] = _get_next_func_of_index_var( - next_func=next_func, - variables=non_index_states, - ) - # Update grids # ---------------------------------------------------------------------------------- for var in non_index_discrete_vars: grid: DiscreteGrid = gridspecs[var] # type: ignore[assignment] index_grid = DiscreteGrid(options=list(range(len(grid.options)))) - if var in states: + if var in model.states: states.pop(var) states[f"__{var}_index__"] = index_grid else: choices.pop(var) choices[f"__{var}_index__"] = index_grid + # Update next functions + # ---------------------------------------------------------------------------------- + non_index_states = [s for s in model.states if s in non_index_discrete_vars] + + for var in non_index_states: + functions[f"next___{var}_index__"] = functions.pop(f"next_{var}") + # Add index to label functions # ---------------------------------------------------------------------------------- index_to_label_funcs = { @@ -189,28 +179,6 @@ def convert_discrete_options_to_indices( return new_model, converter -def _get_next_func_of_index_var(next_func: Callable, variables: list[str]) -> Callable: - """Create next function for corresponding index variable.""" - arg_names = list(inspect.signature(next_func).parameters) - - relevant_vars = [var for var in variables if var in arg_names] - - if not relevant_vars: - return next_func - - for var in relevant_vars: - arg_names[arg_names.index(var)] = f"__{var}_index__" - - @with_signature(args=arg_names) - def next_func_of_index_var(*args, **kwargs): - kwargs = all_as_kwargs(args, kwargs, arg_names=arg_names) - for var in relevant_vars: - kwargs[var] = kwargs.pop(f"__{var}_index__") - return next_func(**kwargs) - - return next_func_of_index_var - - def _get_discrete_vars_with_non_index_options(model: Model) -> list[str]: """Get discrete variables with non-index options. @@ -274,9 +242,3 @@ def label_to_index(*args, **kwargs): return jnp.argmax(data[:, None] == labels_array[None, :], axis=1) return label_to_index - - -def _func_depends_on(func: Callable, depends_on: list[str]) -> bool: - """Check if any function argument is in the list depends_on.""" - arg_names = list(inspect.signature(func).parameters) - return any(arg in depends_on for arg in arg_names) diff --git a/tests/input_processing/test_discrete_state_conversion.py b/tests/input_processing/test_discrete_state_conversion.py index 92de16b..f34efe2 100644 --- a/tests/input_processing/test_discrete_state_conversion.py +++ b/tests/input_processing/test_discrete_state_conversion.py @@ -7,10 +7,8 @@ from lcm import DiscreteGrid from lcm.input_processing.discrete_state_conversion import ( - _func_depends_on, _get_discrete_vars_with_non_index_options, _get_index_to_label_func, - _get_next_func_of_index_var, convert_discrete_options_to_indices, ) @@ -73,20 +71,3 @@ def test_convert_discrete_options_to_indices(model): assert_array_equal(got.states["__c_index__"], DiscreteGrid([0, 1])) assert got.functions["c"](0) == 1 assert got.functions["c"](1) == 10 - - -def test_func_depends_on(): - def foo(a, b): - pass - - assert _func_depends_on(foo, depends_on=["a", "b"]) - assert not _func_depends_on(foo, depends_on=["c"]) - - -def test_get_next_func_of_index_var(): - def next_a(a): - return a - - got = _get_next_func_of_index_var(next_a, variables=["a"]) - assert got(__a_index__=0) == 0 - assert got(2) == 2