Skip to content

Commit

Permalink
Remove replace method from grids
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 5, 2024
1 parent 048ce87 commit 316810c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 82 deletions.
66 changes: 0 additions & 66 deletions src/lcm/grids.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""Collection of classes that are used by the user to define the model and grids."""

import dataclasses as dc
from abc import ABC, abstractmethod
from collections.abc import Collection
from dataclasses import dataclass, field
from typing import NotRequired, TypedDict, cast

import jax.numpy as jnp

Expand Down Expand Up @@ -54,18 +52,6 @@ def to_jax(self) -> jnp.ndarray:
"""Convert the grid to a Jax array."""
return jnp.array(list(self.options))

def replace(self, options: Collection[int | float]) -> "DiscreteGrid":
"""Replace the grid with new values.
Args:
options: The new options in the grid.
Returns:
The updated grid.
"""
return dc.replace(self, options=options)


@dataclass(frozen=True)
class ContinuousGrid(Grid):
Expand Down Expand Up @@ -104,14 +90,6 @@ def to_jax(self) -> jnp.ndarray:
)


class ContinuousGridReplacements(TypedDict):
"""Dictionary of arguments that can be replaced using the `replace` method."""

start: NotRequired[int | float]
stop: NotRequired[int | float]
n_points: NotRequired[int]


class LinspaceGrid(ContinuousGrid):
"""A linear grid of continuous values.
Expand All @@ -128,28 +106,6 @@ class LinspaceGrid(ContinuousGrid):

kind: ContinuousGridType = "linspace"

def replace(
self,
start: float | None = None,
stop: float | None = None,
n_points: int | None = None,
) -> "LinspaceGrid":
"""Replace the grid with new values.
Args:
start: The new start value of the grid.
stop: The new stop value of the grid.
n_points: The new number of points in the grid.
Returns:
The updated grid.
"""
replacements = {"start": start, "stop": stop, "n_points": n_points}
replacements = {k: v for k, v in replacements.items() if v is not None}
kwargs = cast(ContinuousGridReplacements, replacements)
return dc.replace(self, **kwargs)


class LogspaceGrid(ContinuousGrid):
"""A logarithmic grid of continuous values.
Expand All @@ -167,28 +123,6 @@ class LogspaceGrid(ContinuousGrid):

kind: ContinuousGridType = "logspace"

def replace(
self,
start: float | None = None,
stop: float | None = None,
n_points: int | None = None,
) -> "LogspaceGrid":
"""Replace the grid with new values.
Args:
start: The new start value of the grid.
stop: The new stop value of the grid.
n_points: The new number of points in the grid.
Returns:
The updated grid.
"""
replacements = {"start": start, "stop": stop, "n_points": n_points}
replacements = {k: v for k, v in replacements.items() if v is not None}
kwargs = cast(ContinuousGridReplacements, replacements)
return dc.replace(self, **kwargs)


# ======================================================================================
# Validate user input
Expand Down
29 changes: 13 additions & 16 deletions tests/test_solution_on_toy_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test analytical solution and simulation with only discrete choices."""

from copy import deepcopy
from dataclasses import replace

import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -312,8 +313,9 @@ def analytical_simulate_stochastic(initial_wealth, initial_health, health_1, par
def test_deterministic_solve(beta, n_wealth_points):
# Update model
# ==================================================================================
model = deepcopy(DETERMINISTIC_MODEL)
model.states["wealth"] = model.states["wealth"].replace(n_points=n_wealth_points)
new_states = DETERMINISTIC_MODEL.states
new_states["wealth"] = replace(new_states["wealth"], n_points=n_wealth_points)
model = DETERMINISTIC_MODEL.replace(states=new_states)

# Solve model using LCM
# ==================================================================================
Expand Down Expand Up @@ -345,10 +347,9 @@ def test_deterministic_solve(beta, n_wealth_points):
def test_deterministic_simulate(beta, n_wealth_points):
# Update model
# ==================================================================================
model = deepcopy(DETERMINISTIC_MODEL)
model.states["wealth"] = model.states["wealth"].replace(
n_points=n_wealth_points,
)
new_states = DETERMINISTIC_MODEL.states
new_states["wealth"] = replace(new_states["wealth"], n_points=n_wealth_points)
model = DETERMINISTIC_MODEL.replace(states=new_states)

# Simulate model using LCM
# ==================================================================================
Expand Down Expand Up @@ -382,14 +383,11 @@ def test_deterministic_simulate(beta, n_wealth_points):
@pytest.mark.parametrize("n_wealth_points", [100, 1_000])
@pytest.mark.parametrize("health_transition", HEALTH_TRANSITION)
def test_stochastic_solve(beta, n_wealth_points, health_transition):
beta = 0.9
n_wealth_points = 100
# Update model
# ==================================================================================
model = deepcopy(STOCHASTIC_MODEL)
model.states["wealth"] = model.states["wealth"].replace(
n_points=n_wealth_points,
)
new_states = STOCHASTIC_MODEL.states
new_states["wealth"] = replace(new_states["wealth"], n_points=n_wealth_points)
model = STOCHASTIC_MODEL.replace(states=new_states)

# Solve model using LCM
# ==================================================================================
Expand Down Expand Up @@ -435,10 +433,9 @@ def test_stochastic_solve(beta, n_wealth_points, health_transition):
def test_stochastic_simulate(beta, n_wealth_points, health_transition):
# Update model
# ==================================================================================
model = deepcopy(STOCHASTIC_MODEL)
model.states["wealth"] = model.states["wealth"].replace(
n_points=n_wealth_points,
)
new_states = STOCHASTIC_MODEL.states
new_states["wealth"] = replace(new_states["wealth"], n_points=n_wealth_points)
model = STOCHASTIC_MODEL.replace(states=new_states)

# Simulate model using LCM
# ==================================================================================
Expand Down

0 comments on commit 316810c

Please sign in to comment.