Skip to content

Commit

Permalink
Make validation easier to read.
Browse files Browse the repository at this point in the history
  • Loading branch information
hmgaudecker committed Sep 19, 2024
1 parent efb91ae commit ff42633
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions src/lcm/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,10 @@ def __init__(self, category_class: type) -> None:
values.
"""
if not is_dataclass(category_class):
raise GridInitializationError(
"category_class must be a dataclass with scalar int or float fields, "
f"but is {category_class}."
)
_validate_discrete_grid(category_class)

names_and_values = _get_field_names_and_values(category_class)

errors = _validate_discrete_grid(names_and_values)
if errors:
msg = format_messages(errors)
raise GridInitializationError(msg)

self.__categories = list(names_and_values.keys())
self.__codes = list(names_and_values.values())

Expand Down Expand Up @@ -159,17 +150,21 @@ def get_coordinate(self, value: Scalar) -> Scalar:
# ======================================================================================


def _validate_discrete_grid(names_and_values: dict[str, Any]) -> list[str]:
def _validate_discrete_grid(category_class: type) -> None:
"""Validate the field names and values of the category_class passed to DiscreteGrid.
Args:
names_and_values: A dictionary with the field names as keys and the field
values as values.
Returns:
list[str]: A list of error messages.
category_class: The class with mappings of names to codes.
"""
if not is_dataclass(category_class):
raise GridInitializationError(
"category_class must be a dataclass with scalar int or float fields, "
f"but is {category_class}."
)

names_and_values = _get_field_names_and_values(category_class)

error_messages = []

if not len(names_and_values) > 0:
Expand Down Expand Up @@ -198,7 +193,9 @@ def _validate_discrete_grid(names_and_values: dict[str, Any]) -> list[str]:
f"{set(duplicated_values)}"
)

return error_messages
if error_messages:
msg = format_messages(error_messages)
raise GridInitializationError(msg)


def _get_field_names_and_values(dc: type) -> dict[str, Any]:
Expand Down

0 comments on commit ff42633

Please sign in to comment.