diff --git a/src/lcm/grids.py b/src/lcm/grids.py index fe7844f..acc0c0f 100644 --- a/src/lcm/grids.py +++ b/src/lcm/grids.py @@ -1,10 +1,8 @@ """Collection of classes that are used by the user to define the model and grids.""" -import dataclasses as dc from abc import ABC, abstractmethod from collections.abc import Collection from dataclasses import dataclass, field -from typing import NotRequired, TypedDict, cast import jax.numpy as jnp @@ -54,18 +52,6 @@ def to_jax(self) -> jnp.ndarray: """Convert the grid to a Jax array.""" return jnp.array(list(self.options)) - def replace(self, options: Collection[int | float]) -> "DiscreteGrid": - """Replace the grid with new values. - - Args: - options: The new options in the grid. - - Returns: - The updated grid. - - """ - return dc.replace(self, options=options) - @dataclass(frozen=True) class ContinuousGrid(Grid): @@ -104,14 +90,6 @@ def to_jax(self) -> jnp.ndarray: ) -class ContinuousGridReplacements(TypedDict): - """Dictionary of arguments that can be replaced using the `replace` method.""" - - start: NotRequired[int | float] - stop: NotRequired[int | float] - n_points: NotRequired[int] - - class LinspaceGrid(ContinuousGrid): """A linear grid of continuous values. @@ -128,28 +106,6 @@ class LinspaceGrid(ContinuousGrid): kind: ContinuousGridType = "linspace" - def replace( - self, - start: float | None = None, - stop: float | None = None, - n_points: int | None = None, - ) -> "LinspaceGrid": - """Replace the grid with new values. - - Args: - start: The new start value of the grid. - stop: The new stop value of the grid. - n_points: The new number of points in the grid. - - Returns: - The updated grid. - - """ - replacements = {"start": start, "stop": stop, "n_points": n_points} - replacements = {k: v for k, v in replacements.items() if v is not None} - kwargs = cast(ContinuousGridReplacements, replacements) - return dc.replace(self, **kwargs) - class LogspaceGrid(ContinuousGrid): """A logarithmic grid of continuous values. @@ -167,28 +123,6 @@ class LogspaceGrid(ContinuousGrid): kind: ContinuousGridType = "logspace" - def replace( - self, - start: float | None = None, - stop: float | None = None, - n_points: int | None = None, - ) -> "LogspaceGrid": - """Replace the grid with new values. - - Args: - start: The new start value of the grid. - stop: The new stop value of the grid. - n_points: The new number of points in the grid. - - Returns: - The updated grid. - - """ - replacements = {"start": start, "stop": stop, "n_points": n_points} - replacements = {k: v for k, v in replacements.items() if v is not None} - kwargs = cast(ContinuousGridReplacements, replacements) - return dc.replace(self, **kwargs) - # ====================================================================================== # Validate user input diff --git a/tests/test_solution_on_toy_model.py b/tests/test_solution_on_toy_model.py index 197dd04..3da3b00 100644 --- a/tests/test_solution_on_toy_model.py +++ b/tests/test_solution_on_toy_model.py @@ -1,6 +1,7 @@ """Test analytical solution and simulation with only discrete choices.""" from copy import deepcopy +from dataclasses import replace import jax.numpy as jnp import numpy as np @@ -312,8 +313,9 @@ def analytical_simulate_stochastic(initial_wealth, initial_health, health_1, par def test_deterministic_solve(beta, n_wealth_points): # Update model # ================================================================================== - model = deepcopy(DETERMINISTIC_MODEL) - model.states["wealth"] = model.states["wealth"].replace(n_points=n_wealth_points) + new_states = DETERMINISTIC_MODEL.states + new_states["wealth"] = replace(new_states["wealth"], n_points=n_wealth_points) + model = DETERMINISTIC_MODEL.replace(states=new_states) # Solve model using LCM # ================================================================================== @@ -345,10 +347,9 @@ def test_deterministic_solve(beta, n_wealth_points): def test_deterministic_simulate(beta, n_wealth_points): # Update model # ================================================================================== - model = deepcopy(DETERMINISTIC_MODEL) - model.states["wealth"] = model.states["wealth"].replace( - n_points=n_wealth_points, - ) + new_states = DETERMINISTIC_MODEL.states + new_states["wealth"] = replace(new_states["wealth"], n_points=n_wealth_points) + model = DETERMINISTIC_MODEL.replace(states=new_states) # Simulate model using LCM # ================================================================================== @@ -382,14 +383,11 @@ def test_deterministic_simulate(beta, n_wealth_points): @pytest.mark.parametrize("n_wealth_points", [100, 1_000]) @pytest.mark.parametrize("health_transition", HEALTH_TRANSITION) def test_stochastic_solve(beta, n_wealth_points, health_transition): - beta = 0.9 - n_wealth_points = 100 # Update model # ================================================================================== - model = deepcopy(STOCHASTIC_MODEL) - model.states["wealth"] = model.states["wealth"].replace( - n_points=n_wealth_points, - ) + new_states = STOCHASTIC_MODEL.states + new_states["wealth"] = replace(new_states["wealth"], n_points=n_wealth_points) + model = STOCHASTIC_MODEL.replace(states=new_states) # Solve model using LCM # ================================================================================== @@ -435,10 +433,9 @@ def test_stochastic_solve(beta, n_wealth_points, health_transition): def test_stochastic_simulate(beta, n_wealth_points, health_transition): # Update model # ================================================================================== - model = deepcopy(STOCHASTIC_MODEL) - model.states["wealth"] = model.states["wealth"].replace( - n_points=n_wealth_points, - ) + new_states = STOCHASTIC_MODEL.states + new_states["wealth"] = replace(new_states["wealth"], n_points=n_wealth_points) + model = STOCHASTIC_MODEL.replace(states=new_states) # Simulate model using LCM # ==================================================================================