Skip to content

Commit

Permalink
Allow for general discrete grids
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Jul 15, 2024
1 parent 88a8a82 commit e124a18
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 59 deletions.
114 changes: 93 additions & 21 deletions src/lcm/process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
from dags import get_ancestors
from dags.signature import with_signature
from jax import Array
import jax.numpy as jnp

from lcm.create_params_template import create_params_template
from lcm.functools import all_as_args, all_as_kwargs
from lcm.interfaces import InternalModel
from lcm.typing import Params
from lcm.user_input import (
ContinuousGrid,
DiscreteGrid,
Grid,
Model,
)
Expand Down Expand Up @@ -44,33 +46,103 @@ def process_model(user_model: Model) -> InternalModel:

gridspecs = _get_gridspecs(user_model, variable_info=variable_info)

grids = _get_grids(gridspecs=gridspecs, variable_info=variable_info)
if discrete_vars_to_update := _discrete_vars_with_non_index_options(gridspecs):

params = create_params_template(
user_model,
variable_info=variable_info,
grids=grids,
)
updated_user_model = _update_user_model(
user_model,
gridspecs,
discrete_vars_to_update,
)

functions = _get_functions(
user_model,
function_info=function_info,
variable_info=variable_info,
params=params,
grids=grids,
)
return InternalModel(
grids=grids,
gridspecs=gridspecs,
variable_info=variable_info,
out = process_model(updated_user_model)

else:

grids = _get_grids(gridspecs=gridspecs, variable_info=variable_info)

params = create_params_template(
user_model,
variable_info=variable_info,
grids=grids,
)

functions = _get_functions(
user_model,
function_info=function_info,
variable_info=variable_info,
params=params,
grids=grids,
)
out = InternalModel(
grids=grids,
gridspecs=gridspecs,
variable_info=variable_info,
functions=functions,
function_info=function_info,
params=params,
shocks=user_model.shocks if hasattr(user_model, "shocks") else {},
n_periods=user_model.n_periods,
)

return out


def _discrete_vars_with_non_index_options(gridspecs: dict[str, Grid]) -> list[str]:
vars = []
for name, spec in gridspecs.items():
if isinstance(spec, DiscreteGrid):
if list(spec.options) != list(range(len(spec.options))):
vars.append(name)
return vars


def _update_user_model(
user_model: Model,
gridspecs: dict[str, Grid],
discrete_vars_to_update: list[str],
) -> Model:
"""Update the user model to ensure that discrete variables have index options."""
functions = user_model.functions
states = user_model.states
choices = user_model.choices

for var in discrete_vars_to_update:

index_grid = DiscreteGrid(options=list(range(len(gridspecs[var].options))))

if var in states:
states.pop(var)
states[f"__{var}_index__"] = index_grid
else:
choices.pop(var)
choices[f"__{var}_index__"] = index_grid

functions[var] = _get_index_to_label_func(gridspecs[var].options, name=var)

return user_model.replace(
states=states,
choices=choices,
functions=functions,
function_info=function_info,
params=params,
shocks=user_model.shocks if hasattr(user_model, "shocks") else {},
n_periods=user_model.n_periods,
)


def _get_index_to_label_func(options, name):

options_array = jnp.asarray(options)

arg_name = f"__{name}_index__"

@with_signature(args=[arg_name])
def func(*args, **kwargs):
kwargs = all_as_kwargs(args, kwargs, arg_names=[arg_name])
index = kwargs[arg_name]
return options_array[index]

return func




def _get_function_info(user_model: Model) -> pd.DataFrame:
"""Derive information about functions in the model.
Expand Down
74 changes: 41 additions & 33 deletions src/lcm/user_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,14 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection
from dataclasses import KW_ONLY, InitVar, dataclass, field
from typing import Self, get_args
from typing import get_args

import jax.numpy as jnp

from lcm.grids import linspace, logspace
import lcm.grids as grids_module
from lcm.interfaces import ContinuousGridInfo
from lcm.typing import ContinuousGridType, ScalarUserInput

build_grid_mapping = {
"linspace": linspace,
"logspace": logspace,
}


class Grid(ABC):
"""LCM Grid base class."""
Expand Down Expand Up @@ -99,9 +94,9 @@ def __post_init__(self) -> None:

def to_jax(self) -> jnp.ndarray:
"""Convert the grid to a Jax array."""
return jnp.array(list(self.options))
return jnp.asarray(list(self.options))

def replace(self, options: Collection[ScalarUserInput]) -> Self:
def replace(self, options: Collection[ScalarUserInput]) -> "DiscreteGrid":
"""Replace the grid with new values.
Args:
Expand Down Expand Up @@ -143,26 +138,8 @@ def info(self) -> ContinuousGridInfo:

def to_jax(self) -> jnp.ndarray:
"""Convert the grid to a Jax array."""
return build_grid_mapping[self.kind](
start=self.start,
stop=self.stop,
n_points=self.n_points,
)

def replace(self, **kwargs) -> Self:
"""Replace the grid with new values.
Args:
**kwargs:
- start: The new start value of the grid.
- stop: The new stop value of the grid.
- n_points: The new number of points in the grid.
Returns:
The updated grid.
"""
return dc.replace(self, **kwargs)
space_func = getattr(grids_module, self.kind)
return space_func(start=self.start, stop=self.stop, n_points=self.n_points)


class LinspaceGrid(ContinuousGrid):
Expand All @@ -182,6 +159,22 @@ class LinspaceGrid(ContinuousGrid):
kind: ContinuousGridType = "linspace"


def replace(self, **kwargs) -> "LinspaceGrid":
"""Replace the grid with new values.
Args:
**kwargs:
- start: The new start value of the grid.
- stop: The new stop value of the grid.
- n_points: The new number of points in the grid.
Returns:
The updated grid.
"""
return dc.replace(self, **kwargs)


class LogspaceGrid(ContinuousGrid):
"""A logarithmic grid of continuous values.
Expand All @@ -197,6 +190,21 @@ class LogspaceGrid(ContinuousGrid):
"""

kind: ContinuousGridType = "logspace"

def replace(self, **kwargs) -> "LogspaceGrid":
"""Replace the grid with new values.
Args:
**kwargs:
- start: The new start value of the grid.
- stop: The new stop value of the grid.
- n_points: The new number of points in the grid.
Returns:
The updated grid.
"""
return dc.replace(self, **kwargs)


# ======================================================================================
Expand Down Expand Up @@ -319,10 +327,10 @@ def _validate_discrete_grid(options: Collection[ScalarUserInput]) -> list[str]:
if len(options) != len(set(options)):
error_messages.append("options must contain unique values")

if list(options) != list(range(len(options))):
error_messages.append(
"options must be a list of consecutive integers starting from 0",
)
# if list(options) != list(range(len(options))):
# error_messages.append(
# "options must be a list of consecutive integers starting from 0",
# )

return error_messages

Expand Down
4 changes: 2 additions & 2 deletions tests/test_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def test_create_compute_conditional_continuation_value_with_discrete_model():
)

val = compute_ccv(
consumption_index=jnp.array([0, 1]),
__consumption_index__=jnp.array([0, 1]),
retirement=1,
wealth=2,
params=params,
Expand Down Expand Up @@ -338,7 +338,7 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model():
)

policy, val = compute_ccv_policy(
consumption_index=jnp.array([0, 1]),
__consumption_index__=jnp.array([0, 1]),
retirement=1,
wealth=2,
params=params,
Expand Down
8 changes: 5 additions & 3 deletions tests/test_models/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,16 +189,18 @@ def absorbing_retirement_filter(retirement, lagged_retirement):
),
n_periods=3,
functions={
"utility": utility_fully_discrete,
# "utility": utility_fully_discrete,
"utility": utility,
"next_wealth": next_wealth,
"consumption_constraint": consumption_constraint,
"labor_income": labor_income,
"working": working,
"consumption": consumption,
# "consumption": consumption,
},
choices={
"retirement": DiscreteGrid([0, 1]),
"consumption_index": DiscreteGrid([0, 1]),
# "consumption_index": DiscreteGrid([0, 1]),
"consumption": DiscreteGrid([1, 2]),
},
states={
"wealth": LinspaceGrid(
Expand Down

0 comments on commit e124a18

Please sign in to comment.