Skip to content

Commit

Permalink
Solve #87
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 19, 2024
1 parent 770e56f commit 4d30bf4
Show file tree
Hide file tree
Showing 18 changed files with 335 additions and 169 deletions.
13 changes: 12 additions & 1 deletion examples/long_running.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Example specification for a consumption-savings model with health and exercise."""

from dataclasses import dataclass

import jax.numpy as jnp

from lcm import DiscreteGrid, LinspaceGrid, Model
Expand All @@ -9,6 +11,15 @@
# ======================================================================================


# --------------------------------------------------------------------------------------
# Categorical variables
# --------------------------------------------------------------------------------------
@dataclass
class WorkingState:
retired: int = 0
working: int = 1


# --------------------------------------------------------------------------------------
# Utility function
# --------------------------------------------------------------------------------------
Expand Down Expand Up @@ -67,7 +78,7 @@ def consumption_constraint(consumption, wealth, labor_income):
"age": age,
},
choices={
"working": DiscreteGrid([0, 1]),
"working": DiscreteGrid(WorkingState),
"consumption": LinspaceGrid(
start=1,
stop=100,
Expand Down
12 changes: 10 additions & 2 deletions explanations/dispatchers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"\n",
"import jax.numpy as jnp\n",
"import pytest\n",
"from jax import vmap\n",
Expand Down Expand Up @@ -277,6 +279,12 @@
"from lcm import DiscreteGrid, LinspaceGrid, Model\n",
"\n",
"\n",
"@dataclass\n",
"class RetirementStatus:\n",
" working: int = 0\n",
" retired: int = 1\n",
"\n",
"\n",
"def utility(consumption, retirement, lagged_retirement, wealth):\n",
" working = 1 - retirement\n",
" retirement_habit = lagged_retirement * wealth\n",
Expand All @@ -296,11 +304,11 @@
" },\n",
" n_periods=1,\n",
" choices={\n",
" \"retirement\": DiscreteGrid([0, 1]),\n",
" \"retirement\": DiscreteGrid(RetirementStatus),\n",
" \"consumption\": LinspaceGrid(start=1, stop=2, n_points=2),\n",
" },\n",
" states={\n",
" \"lagged_retirement\": DiscreteGrid([0, 1]),\n",
" \"lagged_retirement\": DiscreteGrid(RetirementStatus),\n",
" \"wealth\": LinspaceGrid(start=1, stop=4, n_points=4),\n",
" },\n",
")"
Expand Down
10 changes: 9 additions & 1 deletion explanations/function_representation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,19 @@
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"\n",
"import jax.numpy as jnp\n",
"\n",
"from lcm import DiscreteGrid, LinspaceGrid, Model\n",
"\n",
"\n",
"@dataclass\n",
"class RetirementStatus:\n",
" working: int = 0\n",
" retired: int = 1\n",
"\n",
"\n",
"def utility(consumption, working, disutility_of_work):\n",
" return jnp.log(consumption) - disutility_of_work * working\n",
"\n",
Expand Down Expand Up @@ -125,7 +133,7 @@
" \"age\": age,\n",
" },\n",
" choices={\n",
" \"retirement\": DiscreteGrid([0, 1]),\n",
" \"retirement\": DiscreteGrid(RetirementStatus),\n",
" \"consumption\": LinspaceGrid(start=1, stop=400, n_points=20),\n",
" },\n",
" states={\n",
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,10 @@ extend-ignore = [
[tool.ruff.lint.per-file-ignores]
"docs/source/conf.py" = ["E501", "ERA001", "DTZ005"]
"tests/*" = ["PLR2004", "D101"]
"examples/*" = ["INP001"]
"examples/*" = ["INP001", "D101"]
"explanations/*" = ["INP001", "B018", "T201", "E402", "PD008"]
"scripts/*" = ["INP001", "D101", "RET503"]
"**/*.ipynb" = ["FBT003", "E402"]
"**/*.ipynb" = ["FBT003", "E402", "D101"]

[tool.ruff.lint.pydocstyle]
convention = "google"
Expand Down
2 changes: 1 addition & 1 deletion src/lcm/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def get_lcm_function(
elif targets == "solve_and_simulate":
_target = partial(simulate_model, solve_model=solve_model)

user_params = _mod.converter.params_from_internal(_mod.params)
user_params = _mod.converter.internal_to_params(_mod.params)
return cast(Callable, _target), user_params


Expand Down
128 changes: 80 additions & 48 deletions src/lcm/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,63 +16,65 @@ class Grid(ABC):
"""LCM Grid base class."""

@abstractmethod
def to_jax(self) -> jnp.ndarray:
def to_jax(self) -> Array:
"""Convert the grid to a Jax array."""


@dataclass(frozen=True)
class DiscreteGrid(Grid):
"""A grid of discrete values.
"""A class representing a discrete grid.
Args:
category_class (type): The category class representing the grid categories. Must
be a dataclass with fields that have unique scalar int or float values.
Attributes:
options: The options in the grid. Must be an iterable of scalar int or float
values.
categories: The list of category names.
codes: The list of category codes.
Raises:
GridInitializationError: If the `category_class` is not a dataclass with scalar
int or float fields.
"""

options: type
def __init__(self, category_class: type) -> None:
"""Initialize the DiscreteGrid.
def __post_init__(self) -> None:
if not is_dataclass(self.options):
Args:
category_class (type): The category class representing the grid categories.
Must be a dataclass with fields that have unique scalar int or float
values.
"""
if not is_dataclass(category_class):
raise GridInitializationError(
"options must be a dataclass with scalar int or float fields, but is "
f"{self.options}."
"category_class must be a dataclass with scalar int or float fields, "
f"but is {category_class}."
)

errors = _validate_discrete_grid(self.options)
names_and_values = _get_field_names_and_values(category_class)

errors = _validate_discrete_grid(names_and_values)
if errors:
msg = format_messages(errors)
raise GridInitializationError(msg)

def to_jax(self) -> Array:
"""Convert the grid to a Jax array."""
return jnp.array(_get_fields(self.options))


def _get_fields(dc: type) -> list[Any]:
"""Get the fields of a dataclass.
Args:
dc: The dataclass to get the fields of.
self.__categories = list(names_and_values.keys())
self.__codes = list(names_and_values.values())

Returns:
list[Any]: The fields of the dataclass.
@property
def categories(self) -> list[str]:
"""Get the list of category names."""
return self.__categories

Raises:
GridInitializationError: If the fields of the dataclass do not have default
values, or the instantiated dataclass does not have all fields. None values
are treated as if they do not exist.
@property
def codes(self) -> list[int | float]:
"""Get the list of category codes."""
return self.__codes

"""
_fields = {field.name: getattr(dc, field.name, None) for field in fields(dc)}
fields_without_defaults = [name for name, value in _fields.items() if value is None]
if fields_without_defaults:
raise GridInitializationError(
f"To use a DiscreteGrid, all fields of the options dataclass must have "
f"default values. The following fields do not have default values: "
f"{fields_without_defaults}."
)
return list(_fields.values())
def to_jax(self) -> Array:
"""Convert the grid to a Jax array."""
return jnp.array(self.codes)


@dataclass(frozen=True, kw_only=True)
Expand Down Expand Up @@ -157,32 +159,62 @@ def get_coordinate(self, value: Scalar) -> Scalar:
# ======================================================================================


def _validate_discrete_grid(options: type) -> list[str]:
"""Validate the discrete grid options.
def _validate_discrete_grid(names_and_values: dict[str, Any]) -> list[str]:
"""Validate the field names and values of the category_class passed to DiscreteGrid.
Args:
options: The user options to validate in form of a dataclass.
names_and_values: A dictionary with the field names as keys and the field
values as values.
Returns:
list[str]: A list of error messages.
"""
values = _get_fields(options)

error_messages = []

if not len(values) > 0:
error_messages.append("options must contain at least one element")
if not len(names_and_values) > 0:
error_messages.append(
"category_class passed to DiscreteGrid must have at least one field"
)

if not all(isinstance(value, int | float) for value in values):
error_messages.append("options must contain only scalar int or float values")
names_with_non_numerical_values = [
name
for name, value in names_and_values.items()
if not isinstance(value, int | float)
]
if names_with_non_numerical_values:
error_messages.append(
"Field values of the category_class passed to DiscreteGrid can only be "
"scalar int or float values. The values to the following fields are not: "
f"{names_with_non_numerical_values}"
)

if len(values) != len(set(values)):
error_messages.append("options must contain unique values")
values = list(names_and_values.values())
duplicated_values = [v for v in values if values.count(v) > 1]
if duplicated_values:
error_messages.append(
"Field values of the category_class passed to DiscreteGrid must be unique. "
"The following values are duplicated: "
f"{set(duplicated_values)}"
)

return error_messages


def _get_field_names_and_values(dc: type) -> dict[str, Any]:
"""Get the fields of a dataclass.
Args:
dc: The dataclass to get the fields of.
Returns:
A dictionary with the field names as keys and the field values as values. If
no value is provided for a field, the value is set to None.
"""
return {field.name: getattr(dc, field.name, None) for field in fields(dc)}


def _validate_continuous_grid(
start: float,
stop: float,
Expand Down
Loading

0 comments on commit 4d30bf4

Please sign in to comment.