From 1bae2e380b21d1989d981c00349a64721b9143a6 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 5 Sep 2024 16:24:16 +0200 Subject: [PATCH] Remove Walrus operator --- src/lcm/create_params_template.py | 6 ++++-- src/lcm/dispatchers.py | 15 ++++++++++----- src/lcm/functools.py | 6 ++++-- src/lcm/grids.py | 9 ++++++--- src/lcm/model.py | 16 +++++++++++----- 5 files changed, 35 insertions(+), 17 deletions(-) diff --git a/src/lcm/create_params_template.py b/src/lcm/create_params_template.py index 6e20646..b0e1848 100644 --- a/src/lcm/create_params_template.py +++ b/src/lcm/create_params_template.py @@ -113,7 +113,8 @@ def _create_stochastic_transition_params( # ================================================================================== discrete_state_vars = set(variable_info.query("is_state & is_discrete").index) - if invalid := set(stochastic_variables) - discrete_state_vars: + invalid = set(stochastic_variables) - discrete_state_vars + if invalid: raise ValueError( f"The following variables are stochastic, but are not discrete state " f"variables: {invalid}. This is currently not supported.", @@ -135,7 +136,8 @@ def _create_stochastic_transition_params( # If there are invalid dependencies, store them in a dictionary and continue # with the next variable to collect as many invalid arguments as possible. - if invalid := set(dependencies) - valid_vars: + invalid = set(dependencies) - valid_vars + if invalid: invalid_dependencies[var] = invalid else: # Get the dimensions of variables that influence the stochastic variable diff --git a/src/lcm/dispatchers.py b/src/lcm/dispatchers.py index ebe941a..d6eed9b 100644 --- a/src/lcm/dispatchers.py +++ b/src/lcm/dispatchers.py @@ -48,17 +48,20 @@ def spacemap( """ # Check inputs and prepare function # ================================================================================== - if overlap := set(dense_vars).intersection(sparse_vars): + overlap = set(dense_vars).intersection(sparse_vars) + if overlap: raise ValueError( f"Dense and sparse variables must be disjoint. Overlap: {overlap}", ) - if duplicates := {v for v in dense_vars if dense_vars.count(v) > 1}: + duplicates = {v for v in dense_vars if dense_vars.count(v) > 1} + if duplicates: raise ValueError( f"Same argument provided more than once in dense variables: {duplicates}", ) - if duplicates := {v for v in sparse_vars if sparse_vars.count(v) > 1}: + duplicates = {v for v in sparse_vars if sparse_vars.count(v) > 1} + if duplicates: raise ValueError( f"Same argument provided more than once in sparse variables: {duplicates}", ) @@ -116,7 +119,8 @@ def vmap_1d( described above but there might be additional dimensions. """ - if duplicates := {v for v in variables if variables.count(v) > 1}: + duplicates = {v for v in variables if variables.count(v) > 1} + if duplicates: raise ValueError( f"Same argument provided more than once in variables: {duplicates}", ) @@ -179,7 +183,8 @@ def productmap(func: F, variables: list[str]) -> F: """ func = allow_args(func) # jax.vmap cannot deal with keyword-only arguments - if duplicates := {v for v in variables if variables.count(v) > 1}: + duplicates = {v for v in variables if variables.count(v) > 1} + if duplicates: raise ValueError( f"Same argument provided more than once in variables: {duplicates}", ) diff --git a/src/lcm/functools.py b/src/lcm/functools.py index 811fa87..58a1d9b 100644 --- a/src/lcm/functools.py +++ b/src/lcm/functools.py @@ -41,12 +41,14 @@ def func_with_only_kwargs(*args, **kwargs): ), ) - if extra := set(kwargs).difference(parameters): + extra = set(kwargs).difference(parameters) + if extra: raise ValueError( f"Expected arguments: {list(parameters)}, got extra: {extra}", ) - if missing := set(parameters).difference(kwargs): + missing = set(parameters).difference(kwargs) + if missing: raise ValueError( f"Expected arguments: {list(parameters)}, missing: {missing}", ) diff --git a/src/lcm/grids.py b/src/lcm/grids.py index 06cd457..15c89bc 100644 --- a/src/lcm/grids.py +++ b/src/lcm/grids.py @@ -39,7 +39,8 @@ def __post_init__(self) -> None: "list or tuple", ) - if errors := _validate_discrete_grid(self.options): + errors = _validate_discrete_grid(self.options) + if errors: msg = format_messages(errors) raise GridInitializationError(msg) @@ -177,10 +178,12 @@ def _validate_continuous_grid( """ error_messages = [] - if not (valid_start_type := isinstance(start, int | float)): + valid_start_type = isinstance(start, int | float) + if not valid_start_type: error_messages.append("start must be a scalar int or float value") - if not (valid_stop_type := isinstance(stop, int | float)): + valid_stop_type = isinstance(stop, int | float) + if not valid_stop_type: error_messages.append("stop must be a scalar int or float value") if not isinstance(n_points, int) or n_points < 1: diff --git a/src/lcm/model.py b/src/lcm/model.py index 5ee7456..c2ba24e 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -31,11 +31,15 @@ class Model: states: dict[str, Grid] = field(default_factory=dict) def __post_init__(self) -> None: - if type_errors := _validate_attribute_types(self): + type_errors = _validate_attribute_types(self) + + if type_errors: msg = format_messages(type_errors) raise ModelInitilizationError(msg) - if logical_errors := _validate_logical_consistency(self): + logical_errors = _validate_logical_consistency(self) + + if logical_errors: msg = format_messages(logical_errors) raise ModelInitilizationError(msg) @@ -98,16 +102,18 @@ def _validate_logical_consistency(model: Model) -> list[str]: "in the functions dictionary.", ) - if states_without_next_func := [ + states_without_next_func = [ state for state in model.states if f"next_{state}" not in model.functions - ]: + ] + if states_without_next_func: error_messages.append( "Each state must have a corresponding next state function. For the " "following states, no next state function was found: " f"{states_without_next_func}.", ) - if states_and_choices_overlap := set(model.states) & set(model.choices): + states_and_choices_overlap = set(model.states) & set(model.choices) + if states_and_choices_overlap: error_messages.append( "States and choices cannot have overlapping names. The following names " f"are used in both states and choices: {states_and_choices_overlap}.",