diff --git a/src/lcm/grids.py b/src/lcm/grids.py index 21941b0..e390c2e 100644 --- a/src/lcm/grids.py +++ b/src/lcm/grids.py @@ -1,8 +1,8 @@ """Collection of classes that are used by the user to define the model and grids.""" from abc import ABC, abstractmethod -from collections.abc import Collection -from dataclasses import dataclass +from dataclasses import dataclass, fields, is_dataclass +from typing import Any import jax.numpy as jnp from jax import Array @@ -30,13 +30,13 @@ class DiscreteGrid(Grid): """ - options: Collection[int | float] + options: type def __post_init__(self) -> None: - if not isinstance(self.options, Collection): + if not is_dataclass(self.options): raise GridInitializationError( - "options must be a collection of scalar int or float values, e.g., a ", - "list or tuple", + "options must be a dataclass with scalar int or float fields, but is " + f"{self.options}." ) errors = _validate_discrete_grid(self.options) @@ -46,7 +46,33 @@ def __post_init__(self) -> None: def to_jax(self) -> Array: """Convert the grid to a Jax array.""" - return jnp.array(list(self.options)) + return jnp.array(_get_fields(self.options)) + + +def _get_fields(dc: type) -> list[Any]: + """Get the fields of a dataclass. + + Args: + dc: The dataclass to get the fields of. + + Returns: + list[Any]: The fields of the dataclass. + + Raises: + GridInitializationError: If the fields of the dataclass do not have default + values, or the instantiated dataclass does not have all fields. None values + are treated as if they do not exist. + + """ + _fields = {field.name: getattr(dc, field.name, None) for field in fields(dc)} + fields_without_defaults = [name for name, value in _fields.items() if value is None] + if fields_without_defaults: + raise GridInitializationError( + f"To use a DiscreteGrid, all fields of the options dataclass must have " + f"default values. The following fields do not have default values: " + f"{fields_without_defaults}." + ) + return list(_fields.values()) @dataclass(frozen=True, kw_only=True) @@ -131,25 +157,27 @@ def get_coordinate(self, value: Scalar) -> Scalar: # ====================================================================================== -def _validate_discrete_grid(options: Collection[int | float]) -> list[str]: +def _validate_discrete_grid(options: type) -> list[str]: """Validate the discrete grid options. Args: - options: The user options to validate. + options: The user options to validate in form of a dataclass. Returns: list[str]: A list of error messages. """ + values = _get_fields(options) + error_messages = [] - if not len(options) > 0: + if not len(values) > 0: error_messages.append("options must contain at least one element") - if not all(isinstance(option, int | float) for option in options): + if not all(isinstance(value, int | float) for value in values): error_messages.append("options must contain only scalar int or float values") - if len(options) != len(set(options)): + if len(values) != len(set(values)): error_messages.append("options must contain unique values") return error_messages diff --git a/tests/test_grids.py b/tests/test_grids.py index a6b8b4b..61bf614 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -1,3 +1,5 @@ +from dataclasses import make_dataclass + import numpy as np import pytest @@ -6,27 +8,54 @@ DiscreteGrid, LinspaceGrid, LogspaceGrid, + _get_fields, _validate_continuous_grid, _validate_discrete_grid, ) def test_validate_discrete_grid_empty(): - assert _validate_discrete_grid([]) == ["options must contain at least one element"] + options = make_dataclass("Options", []) + assert _validate_discrete_grid(options) == [ + "options must contain at least one element" + ] def test_validate_discrete_grid_non_scalar_input(): - assert _validate_discrete_grid([1, "a"]) == [ + options = make_dataclass("Options", [("a", int, 1), ("b", str, "wrong_type")]) + assert _validate_discrete_grid(options) == [ "options must contain only scalar int or float values", ] def test_validate_discrete_grid_non_unique(): - assert _validate_discrete_grid([1, 2, 2]) == [ + options = make_dataclass("Options", [("a", int, 1), ("b", int, 2), ("c", int, 2)]) + assert _validate_discrete_grid(options) == [ "options must contain unique values", ] +def test_get_fields_with_defaults(): + options = make_dataclass("Options", [("a", int, 1), ("b", int, 2), ("c", int, 3)]) + assert _get_fields(options) == [1, 2, 3] + + +def test_get_fields_instance(): + options = make_dataclass("Options", [("a", int), ("b", int)]) + assert _get_fields(options(a=1, b=2)) == [1, 2] + + +def test_get_fields_empty(): + options = make_dataclass("Options", []) + assert _get_fields(options) == [] + + +def test_get_fields_no_defaults(): + options = make_dataclass("Options", [("a", int), ("b", int)]) + with pytest.raises(GridInitializationError, match="To use a DiscreteGrid"): + _get_fields(options) + + def test_validate_continuous_grid_invalid_start(): assert _validate_continuous_grid("a", 1, 10) == [ "start must be a scalar int or float value" @@ -66,7 +95,8 @@ def test_logspace_grid_creation(): def test_discrete_grid_creation(): - grid = DiscreteGrid(options=[0, 1, 2]) + options = make_dataclass("Options", [("a", int, 0), ("b", int, 1), ("c", int, 2)]) + grid = DiscreteGrid(options) assert np.allclose(grid.to_jax(), np.arange(3)) @@ -81,8 +111,9 @@ def test_logspace_grid_invalid_start(): def test_discrete_grid_invalid_options(): + options = make_dataclass("Options", [("a", int, 1), ("b", str, "wrong_type")]) with pytest.raises( GridInitializationError, match="options must contain only scalar int or float values", ): - DiscreteGrid(options=[1, "a"]) + DiscreteGrid(options)