Skip to content

Commit

Permalink
Remove Walrus operator
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 5, 2024
1 parent 6f47653 commit 1bae2e3
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 17 deletions.
6 changes: 4 additions & 2 deletions src/lcm/create_params_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def _create_stochastic_transition_params(
# ==================================================================================
discrete_state_vars = set(variable_info.query("is_state & is_discrete").index)

if invalid := set(stochastic_variables) - discrete_state_vars:
invalid = set(stochastic_variables) - discrete_state_vars
if invalid:
raise ValueError(
f"The following variables are stochastic, but are not discrete state "
f"variables: {invalid}. This is currently not supported.",
Expand All @@ -135,7 +136,8 @@ def _create_stochastic_transition_params(

# If there are invalid dependencies, store them in a dictionary and continue
# with the next variable to collect as many invalid arguments as possible.
if invalid := set(dependencies) - valid_vars:
invalid = set(dependencies) - valid_vars
if invalid:
invalid_dependencies[var] = invalid
else:
# Get the dimensions of variables that influence the stochastic variable
Expand Down
15 changes: 10 additions & 5 deletions src/lcm/dispatchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,20 @@ def spacemap(
"""
# Check inputs and prepare function
# ==================================================================================
if overlap := set(dense_vars).intersection(sparse_vars):
overlap = set(dense_vars).intersection(sparse_vars)
if overlap:
raise ValueError(
f"Dense and sparse variables must be disjoint. Overlap: {overlap}",
)

if duplicates := {v for v in dense_vars if dense_vars.count(v) > 1}:
duplicates = {v for v in dense_vars if dense_vars.count(v) > 1}
if duplicates:
raise ValueError(
f"Same argument provided more than once in dense variables: {duplicates}",
)

if duplicates := {v for v in sparse_vars if sparse_vars.count(v) > 1}:
duplicates = {v for v in sparse_vars if sparse_vars.count(v) > 1}
if duplicates:
raise ValueError(
f"Same argument provided more than once in sparse variables: {duplicates}",
)
Expand Down Expand Up @@ -116,7 +119,8 @@ def vmap_1d(
described above but there might be additional dimensions.
"""
if duplicates := {v for v in variables if variables.count(v) > 1}:
duplicates = {v for v in variables if variables.count(v) > 1}
if duplicates:
raise ValueError(
f"Same argument provided more than once in variables: {duplicates}",
)
Expand Down Expand Up @@ -179,7 +183,8 @@ def productmap(func: F, variables: list[str]) -> F:
"""
func = allow_args(func) # jax.vmap cannot deal with keyword-only arguments

if duplicates := {v for v in variables if variables.count(v) > 1}:
duplicates = {v for v in variables if variables.count(v) > 1}
if duplicates:
raise ValueError(
f"Same argument provided more than once in variables: {duplicates}",
)
Expand Down
6 changes: 4 additions & 2 deletions src/lcm/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ def func_with_only_kwargs(*args, **kwargs):
),
)

if extra := set(kwargs).difference(parameters):
extra = set(kwargs).difference(parameters)
if extra:
raise ValueError(
f"Expected arguments: {list(parameters)}, got extra: {extra}",
)

if missing := set(parameters).difference(kwargs):
missing = set(parameters).difference(kwargs)
if missing:
raise ValueError(
f"Expected arguments: {list(parameters)}, missing: {missing}",
)
Expand Down
9 changes: 6 additions & 3 deletions src/lcm/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def __post_init__(self) -> None:
"list or tuple",
)

if errors := _validate_discrete_grid(self.options):
errors = _validate_discrete_grid(self.options)
if errors:
msg = format_messages(errors)
raise GridInitializationError(msg)

Expand Down Expand Up @@ -177,10 +178,12 @@ def _validate_continuous_grid(
"""
error_messages = []

if not (valid_start_type := isinstance(start, int | float)):
valid_start_type = isinstance(start, int | float)
if not valid_start_type:
error_messages.append("start must be a scalar int or float value")

if not (valid_stop_type := isinstance(stop, int | float)):
valid_stop_type = isinstance(stop, int | float)
if not valid_stop_type:
error_messages.append("stop must be a scalar int or float value")

if not isinstance(n_points, int) or n_points < 1:
Expand Down
16 changes: 11 additions & 5 deletions src/lcm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ class Model:
states: dict[str, Grid] = field(default_factory=dict)

def __post_init__(self) -> None:
if type_errors := _validate_attribute_types(self):
type_errors = _validate_attribute_types(self)

if type_errors:
msg = format_messages(type_errors)
raise ModelInitilizationError(msg)

if logical_errors := _validate_logical_consistency(self):
logical_errors = _validate_logical_consistency(self)

if logical_errors:
msg = format_messages(logical_errors)
raise ModelInitilizationError(msg)

Expand Down Expand Up @@ -98,16 +102,18 @@ def _validate_logical_consistency(model: Model) -> list[str]:
"in the functions dictionary.",
)

if states_without_next_func := [
states_without_next_func = [
state for state in model.states if f"next_{state}" not in model.functions
]:
]
if states_without_next_func:
error_messages.append(
"Each state must have a corresponding next state function. For the "
"following states, no next state function was found: "
f"{states_without_next_func}.",
)

if states_and_choices_overlap := set(model.states) & set(model.choices):
states_and_choices_overlap = set(model.states) & set(model.choices)
if states_and_choices_overlap:
error_messages.append(
"States and choices cannot have overlapping names. The following names "
f"are used in both states and choices: {states_and_choices_overlap}.",
Expand Down

0 comments on commit 1bae2e3

Please sign in to comment.