Skip to content

Commit

Permalink
Remove unecessary complexity
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 17, 2024
1 parent 4b3143d commit 27b8914
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 68 deletions.
60 changes: 11 additions & 49 deletions src/lcm/input_processing/discrete_state_conversion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import inspect
from collections.abc import Callable
from dataclasses import dataclass, field

Expand Down Expand Up @@ -100,8 +99,8 @@ def convert_discrete_options_to_indices(
1. Remove the variable from the states or choices dictionary
2. Replace it with a new state or choice with index options (__{var}_index__)
3. Add a function that maps the index options to the original options
4. Add updated next functions (if the variable was a state variable)
3. Add updated next functions (if the variable was a state variable)
4. Add a function that maps the index options to the original options
Args:
model: The model as provided by the user.
Expand All @@ -116,43 +115,34 @@ def convert_discrete_options_to_indices(

non_index_discrete_vars = _get_discrete_vars_with_non_index_options(model)

# fast path
if not non_index_discrete_vars:
# fast path
return model, Converter()

functions = model.functions.copy()
states = model.states.copy()
choices = model.choices.copy()

# Update next functions (needs to be done before updating the grids, otherwise the
# already updated state variables are being used)
# ----------------------------------------------------------------------------------
non_index_states = [s for s in states if s in non_index_discrete_vars]

for state in model.states:
next_func = model.functions[f"next_{state}"]
must_be_updated = _func_depends_on(next_func, depends_on=non_index_states)
if must_be_updated:
functions.pop(f"next_{state}")

functions[f"next___{state}_index__"] = _get_next_func_of_index_var(
next_func=next_func,
variables=non_index_states,
)

# Update grids
# ----------------------------------------------------------------------------------
for var in non_index_discrete_vars:
grid: DiscreteGrid = gridspecs[var] # type: ignore[assignment]
index_grid = DiscreteGrid(options=list(range(len(grid.options))))

if var in states:
if var in model.states:
states.pop(var)
states[f"__{var}_index__"] = index_grid
else:
choices.pop(var)
choices[f"__{var}_index__"] = index_grid

# Update next functions
# ----------------------------------------------------------------------------------
non_index_states = [s for s in model.states if s in non_index_discrete_vars]

for var in non_index_states:
functions[f"next___{var}_index__"] = functions.pop(f"next_{var}")

# Add index to label functions
# ----------------------------------------------------------------------------------
index_to_label_funcs = {
Expand Down Expand Up @@ -189,28 +179,6 @@ def convert_discrete_options_to_indices(
return new_model, converter


def _get_next_func_of_index_var(next_func: Callable, variables: list[str]) -> Callable:
"""Create next function for corresponding index variable."""
arg_names = list(inspect.signature(next_func).parameters)

relevant_vars = [var for var in variables if var in arg_names]

if not relevant_vars:
return next_func

for var in relevant_vars:
arg_names[arg_names.index(var)] = f"__{var}_index__"

@with_signature(args=arg_names)
def next_func_of_index_var(*args, **kwargs):
kwargs = all_as_kwargs(args, kwargs, arg_names=arg_names)
for var in relevant_vars:
kwargs[var] = kwargs.pop(f"__{var}_index__")
return next_func(**kwargs)

return next_func_of_index_var


def _get_discrete_vars_with_non_index_options(model: Model) -> list[str]:
"""Get discrete variables with non-index options.
Expand Down Expand Up @@ -274,9 +242,3 @@ def label_to_index(*args, **kwargs):
return jnp.argmax(data[:, None] == labels_array[None, :], axis=1)

return label_to_index


def _func_depends_on(func: Callable, depends_on: list[str]) -> bool:
"""Check if any function argument is in the list depends_on."""
arg_names = list(inspect.signature(func).parameters)
return any(arg in depends_on for arg in arg_names)
19 changes: 0 additions & 19 deletions tests/input_processing/test_discrete_state_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@

from lcm import DiscreteGrid
from lcm.input_processing.discrete_state_conversion import (
_func_depends_on,
_get_discrete_vars_with_non_index_options,
_get_index_to_label_func,
_get_next_func_of_index_var,
convert_discrete_options_to_indices,
)

Expand Down Expand Up @@ -73,20 +71,3 @@ def test_convert_discrete_options_to_indices(model):
assert_array_equal(got.states["__c_index__"], DiscreteGrid([0, 1]))
assert got.functions["c"](0) == 1
assert got.functions["c"](1) == 10


def test_func_depends_on():
def foo(a, b):
pass

assert _func_depends_on(foo, depends_on=["a", "b"])
assert not _func_depends_on(foo, depends_on=["c"])


def test_get_next_func_of_index_var():
def next_a(a):
return a

got = _get_next_func_of_index_var(next_a, variables=["a"])
assert got(__a_index__=0) == 0
assert got(2) == 2

0 comments on commit 27b8914

Please sign in to comment.