Skip to content

Commit

Permalink
Split user_input module into grids and model module
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Jul 18, 2024
1 parent 8ce39ee commit e54755f
Show file tree
Hide file tree
Showing 8 changed files with 187 additions and 159 deletions.
22 changes: 11 additions & 11 deletions explanations/function_representation.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/lcm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from lcm import mark
from lcm.user_input import DiscreteGrid, LinspaceGrid, LogspaceGrid, Model
from lcm.grids import DiscreteGrid, LinspaceGrid, LogspaceGrid
from lcm.model import Model

__all__ = ["mark", "Model", "LinspaceGrid", "LogspaceGrid", "DiscreteGrid"]
2 changes: 1 addition & 1 deletion src/lcm/create_params_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import pandas as pd
from jax import Array

from lcm.model import Model
from lcm.typing import Params, ScalarUserInput
from lcm.user_input import Model


def create_params_template(
Expand Down
2 changes: 1 addition & 1 deletion src/lcm/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from lcm.discrete_problem import get_solve_discrete_problem
from lcm.dispatchers import productmap
from lcm.logging import get_logger
from lcm.model import Model
from lcm.model_functions import (
get_utility_and_feasibility_function,
)
Expand All @@ -18,7 +19,6 @@
from lcm.simulate import simulate
from lcm.solve_brute import solve
from lcm.state_space import create_state_choice_space
from lcm.user_input import Model


def get_lcm_function(
Expand Down
177 changes: 37 additions & 140 deletions src/lcm/user_input.py → src/lcm/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import dataclasses as dc
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection
from dataclasses import KW_ONLY, InitVar, dataclass, field
from typing import Self, get_args
from collections.abc import Collection
from dataclasses import dataclass, field
from typing import get_args

import jax.numpy as jnp

Expand All @@ -26,54 +26,6 @@ def to_jax(self) -> jnp.ndarray:
"""Convert the grid to a Jax array."""


@dataclass(frozen=True)
class Model:
"""A user model which can be processed into an internal model.
Attributes:
description: Description of the model.
n_periods: Number of periods in the model.
functions: Dictionary of user provided functions that define the functional
relationships between model variables. It must include at least a function
called 'utility'.
choices: Dictionary of user provided choices.
states: Dictionary of user provided states.
"""

description: str | None = None
_: KW_ONLY
n_periods: int
functions: dict[str, Callable] = field(default_factory=dict)
choices: dict[str, Grid] = field(default_factory=dict)
states: dict[str, Grid] = field(default_factory=dict)
_skip_checks: InitVar[bool] = False

def __post_init__(self, _skip_checks: bool) -> None:
if _skip_checks:
return

type_errors = _validate_model_attribute_types(self)
if type_errors:
raise LcmModelInitializationError(_format_errors(type_errors))

logical_errors = _validate_logical_consistency_model(self)
if logical_errors:
raise LcmModelInitializationError(_format_errors(logical_errors))

def replace(self, **kwargs) -> "Model":
"""Replace the attributes of the model.
Args:
**kwargs: Keyword arguments to replace the attributes of the model.
Returns:
A new model with the replaced attributes.
"""
return dc.replace(self, **kwargs)


@dataclass(frozen=True)
class DiscreteGrid(Grid):
"""A grid of discrete values.
Expand All @@ -95,13 +47,13 @@ def __post_init__(self) -> None:

errors = _validate_discrete_grid(self.options)
if errors:
raise LcmGridInitializationError(_format_errors(errors))
raise LcmGridInitializationError(format_errors(errors))

def to_jax(self) -> jnp.ndarray:
"""Convert the grid to a Jax array."""
return jnp.array(list(self.options))

def replace(self, options: Collection[ScalarUserInput]) -> Self:
def replace(self, options: Collection[ScalarUserInput]) -> "DiscreteGrid":
"""Replace the grid with new values.
Args:
Expand Down Expand Up @@ -130,7 +82,7 @@ def __post_init__(self) -> None:
n_points=self.n_points,
)
if errors:
raise LcmGridInitializationError(_format_errors(errors))
raise LcmGridInitializationError(format_errors(errors))

@property
def info(self) -> ContinuousGridInfo:
Expand All @@ -149,21 +101,6 @@ def to_jax(self) -> jnp.ndarray:
n_points=self.n_points,
)

def replace(self, **kwargs) -> Self:
"""Replace the grid with new values.
Args:
**kwargs:
- start: The new start value of the grid.
- stop: The new stop value of the grid.
- n_points: The new number of points in the grid.
Returns:
The updated grid.
"""
return dc.replace(self, **kwargs)


class LinspaceGrid(ContinuousGrid):
"""A linear grid of continuous values.
Expand All @@ -181,6 +118,21 @@ class LinspaceGrid(ContinuousGrid):

kind: ContinuousGridType = "linspace"

def replace(self, **kwargs) -> "LinspaceGrid":
"""Replace the grid with new values.
Args:
**kwargs:
- start: The new start value of the grid.
- stop: The new stop value of the grid.
- n_points: The new number of points in the grid.
Returns:
The updated grid.
"""
return dc.replace(self, **kwargs)


class LogspaceGrid(ContinuousGrid):
"""A logarithmic grid of continuous values.
Expand All @@ -198,21 +150,32 @@ class LogspaceGrid(ContinuousGrid):

kind: ContinuousGridType = "logspace"

def replace(self, **kwargs) -> "LogspaceGrid":
"""Replace the grid with new values.
Args:
**kwargs:
- start: The new start value of the grid.
- stop: The new stop value of the grid.
- n_points: The new number of points in the grid.
Returns:
The updated grid.
"""
return dc.replace(self, **kwargs)


# ======================================================================================
# Validate user input
# ======================================================================================


class LcmModelInitializationError(Exception):
"""Raised when there is an error in the model initialization."""


class LcmGridInitializationError(Exception):
"""Raised when there is an error in the grid initialization."""


def _format_errors(errors: list[str]) -> str:
def format_errors(errors: list[str]) -> str:
"""Convert list of error messages into a single string.
If list is empty, returns the empty string.
Expand All @@ -228,72 +191,6 @@ def _format_errors(errors: list[str]) -> str:
return formatted


# Model
# ======================================================================================


def _validate_model_attribute_types(model: Model) -> list[str]:
"""Validate the types of the model attributes."""
error_messages = []

# Validate types of states and choices
# ----------------------------------------------------------------------------------
for attr_name in ("choices", "states"):
attr = getattr(model, attr_name)
if not isinstance(attr, dict):
error_messages.append(f"{attr_name} must be a dictionary.")
else:
for k, v in attr.items():
if not isinstance(k, str):
error_messages.append(f"{attr_name} key {k} must be a string.")
if not isinstance(v, Grid):
error_messages.append(f"{attr_name} value {v} must be a LCM grid.")

# Validate types of functions
# ----------------------------------------------------------------------------------
if not isinstance(model.functions, dict):
error_messages.append("functions must be a dictionary.")
else:
for k, v in model.functions.items():
if not isinstance(k, str):
error_messages.append(f"functions key {k} must be a string.")
if not callable(v):
error_messages.append(f"functions value {v} must be a callable.")

return error_messages


def _validate_logical_consistency_model(model: Model) -> list[str]:
"""Validate the logical consistency of the model."""
error_messages = []

if model.n_periods < 1:
error_messages.append("Number of periods must be a positive integer.")

if "utility" not in model.functions:
error_messages.append(
"Utility function is not defined. LCM expects a function called 'utility'"
"in the functions dictionary.",
)

if states_without_next_func := [
state for state in model.states if f"next_{state}" not in model.functions
]:
error_messages.append(
"Each state must have a corresponding next state function. For the "
"following states, no next state function was found: "
f"{states_without_next_func}.",
)

if states_and_choices_overlap := set(model.states) & set(model.choices):
error_messages.append(
"States and choices cannot have overlapping names. The following names "
f"are used in both states and choices: {states_and_choices_overlap}.",
)

return error_messages


# Discrete grid
# ======================================================================================

Expand Down
Loading

0 comments on commit e54755f

Please sign in to comment.