Skip to content

Commit

Permalink
First step: Only allow dataclass as input to discrete grid
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 19, 2024
1 parent 000bc2a commit 770e56f
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 17 deletions.
52 changes: 40 additions & 12 deletions src/lcm/grids.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Collection of classes that are used by the user to define the model and grids."""

from abc import ABC, abstractmethod
from collections.abc import Collection
from dataclasses import dataclass
from dataclasses import dataclass, fields, is_dataclass
from typing import Any

import jax.numpy as jnp
from jax import Array
Expand Down Expand Up @@ -30,13 +30,13 @@ class DiscreteGrid(Grid):
"""

options: Collection[int | float]
options: type

def __post_init__(self) -> None:
if not isinstance(self.options, Collection):
if not is_dataclass(self.options):
raise GridInitializationError(
"options must be a collection of scalar int or float values, e.g., a ",
"list or tuple",
"options must be a dataclass with scalar int or float fields, but is "
f"{self.options}."
)

errors = _validate_discrete_grid(self.options)
Expand All @@ -46,7 +46,33 @@ def __post_init__(self) -> None:

def to_jax(self) -> Array:
"""Convert the grid to a Jax array."""
return jnp.array(list(self.options))
return jnp.array(_get_fields(self.options))


def _get_fields(dc: type) -> list[Any]:
"""Get the fields of a dataclass.
Args:
dc: The dataclass to get the fields of.
Returns:
list[Any]: The fields of the dataclass.
Raises:
GridInitializationError: If the fields of the dataclass do not have default
values, or the instantiated dataclass does not have all fields. None values
are treated as if they do not exist.
"""
_fields = {field.name: getattr(dc, field.name, None) for field in fields(dc)}
fields_without_defaults = [name for name, value in _fields.items() if value is None]
if fields_without_defaults:
raise GridInitializationError(
f"To use a DiscreteGrid, all fields of the options dataclass must have "
f"default values. The following fields do not have default values: "
f"{fields_without_defaults}."
)
return list(_fields.values())


@dataclass(frozen=True, kw_only=True)
Expand Down Expand Up @@ -131,25 +157,27 @@ def get_coordinate(self, value: Scalar) -> Scalar:
# ======================================================================================


def _validate_discrete_grid(options: Collection[int | float]) -> list[str]:
def _validate_discrete_grid(options: type) -> list[str]:
"""Validate the discrete grid options.
Args:
options: The user options to validate.
options: The user options to validate in form of a dataclass.
Returns:
list[str]: A list of error messages.
"""
values = _get_fields(options)

error_messages = []

if not len(options) > 0:
if not len(values) > 0:
error_messages.append("options must contain at least one element")

if not all(isinstance(option, int | float) for option in options):
if not all(isinstance(value, int | float) for value in values):
error_messages.append("options must contain only scalar int or float values")

if len(options) != len(set(options)):
if len(values) != len(set(values)):
error_messages.append("options must contain unique values")

return error_messages
Expand Down
41 changes: 36 additions & 5 deletions tests/test_grids.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from dataclasses import make_dataclass

import numpy as np
import pytest

Expand All @@ -6,27 +8,54 @@
DiscreteGrid,
LinspaceGrid,
LogspaceGrid,
_get_fields,
_validate_continuous_grid,
_validate_discrete_grid,
)


def test_validate_discrete_grid_empty():
assert _validate_discrete_grid([]) == ["options must contain at least one element"]
options = make_dataclass("Options", [])
assert _validate_discrete_grid(options) == [
"options must contain at least one element"
]


def test_validate_discrete_grid_non_scalar_input():
assert _validate_discrete_grid([1, "a"]) == [
options = make_dataclass("Options", [("a", int, 1), ("b", str, "wrong_type")])
assert _validate_discrete_grid(options) == [
"options must contain only scalar int or float values",
]


def test_validate_discrete_grid_non_unique():
assert _validate_discrete_grid([1, 2, 2]) == [
options = make_dataclass("Options", [("a", int, 1), ("b", int, 2), ("c", int, 2)])
assert _validate_discrete_grid(options) == [
"options must contain unique values",
]


def test_get_fields_with_defaults():
options = make_dataclass("Options", [("a", int, 1), ("b", int, 2), ("c", int, 3)])
assert _get_fields(options) == [1, 2, 3]


def test_get_fields_instance():
options = make_dataclass("Options", [("a", int), ("b", int)])
assert _get_fields(options(a=1, b=2)) == [1, 2]


def test_get_fields_empty():
options = make_dataclass("Options", [])
assert _get_fields(options) == []


def test_get_fields_no_defaults():
options = make_dataclass("Options", [("a", int), ("b", int)])
with pytest.raises(GridInitializationError, match="To use a DiscreteGrid"):
_get_fields(options)


def test_validate_continuous_grid_invalid_start():
assert _validate_continuous_grid("a", 1, 10) == [
"start must be a scalar int or float value"
Expand Down Expand Up @@ -66,7 +95,8 @@ def test_logspace_grid_creation():


def test_discrete_grid_creation():
grid = DiscreteGrid(options=[0, 1, 2])
options = make_dataclass("Options", [("a", int, 0), ("b", int, 1), ("c", int, 2)])
grid = DiscreteGrid(options)
assert np.allclose(grid.to_jax(), np.arange(3))


Expand All @@ -81,8 +111,9 @@ def test_logspace_grid_invalid_start():


def test_discrete_grid_invalid_options():
options = make_dataclass("Options", [("a", int, 1), ("b", str, "wrong_type")])
with pytest.raises(
GridInitializationError,
match="options must contain only scalar int or float values",
):
DiscreteGrid(options=[1, "a"])
DiscreteGrid(options)

0 comments on commit 770e56f

Please sign in to comment.