Skip to content

Commit

Permalink
Refactor model updating into single function call
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 16, 2024
1 parent d1e3164 commit 5151735
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions src/lcm/input_processing/process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ def process_model(model: Model) -> InternalModel:
The processed model.
"""
discrete_vars_to_update = _get_discrete_vars_with_non_index_options(model)
if discrete_vars_to_update:
model = _get_model_with_only_index_options(model, discrete_vars_to_update)
model = _convert_discrete_options_to_indices(model)

params = create_params_template(model)

Expand Down Expand Up @@ -246,29 +244,28 @@ def weight_func(*args, **kwargs):
return weight_func


def _get_discrete_vars_with_non_index_options(model: Model) -> list[str]:
gridspecs = get_gridspecs(model)
discrete_vars = []
for name, spec in gridspecs.items():
if isinstance(spec, DiscreteGrid) and list(spec.options) != list(
range(len(spec.options))
):
discrete_vars.append(name)
return discrete_vars
def _convert_discrete_options_to_indices(model: Model) -> Model:
"""Update the user model to ensure that discrete variables have index options.
Args:
model: The model as provided by the user.
def _get_model_with_only_index_options(
model: Model,
discrete_vars_to_update: list[str],
) -> Model:
"""Update the user model to ensure that discrete variables have index options."""
Returns:
The model with all discrete variables having index options.
"""
gridspecs = get_gridspecs(model)

non_index_discrete_vars = _get_discrete_vars_with_non_index_options(model)

if not non_index_discrete_vars:
return model

functions = model.functions.copy()
states = model.states.copy()
choices = model.choices.copy()

for var in discrete_vars_to_update:
for var in non_index_discrete_vars:
grid = gridspecs[var]
if isinstance(grid, DiscreteGrid):
index_grid = DiscreteGrid(options=list(range(len(grid.options))))
Expand All @@ -289,6 +286,17 @@ def _get_model_with_only_index_options(
)


def _get_discrete_vars_with_non_index_options(model: Model) -> list[str]:
gridspecs = get_gridspecs(model)
discrete_vars = []
for name, spec in gridspecs.items():
if isinstance(spec, DiscreteGrid) and list(spec.options) != list(
range(len(spec.options))
):
discrete_vars.append(name)
return discrete_vars


def _get_index_to_label_func(labels_array, name):
arg_name = f"__{name}_index__"

Expand Down

0 comments on commit 5151735

Please sign in to comment.