Skip to content

Commit

Permalink
Add get_coordinate method to grid classes
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 5, 2024
1 parent c5f2a28 commit 58db89d
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 179 deletions.
133 changes: 6 additions & 127 deletions pixi.lock

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ plotly = ">=5.2,<6"
pre-commit = "*"
snakeviz = "*"
memory_profiler = "*"
pixi-kernel = ">=0.3.0,<0.4"

[tool.pixi.target.unix.dependencies]
jax = ">=0.4.20"
Expand Down
19 changes: 7 additions & 12 deletions src/lcm/function_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from jax import Array
from jax.scipy.ndimage import map_coordinates

import lcm.grid_helpers as grids_module
from lcm.functools import all_as_kwargs
from lcm.interfaces import ContinuousGridInfo, ContinuousGridType, SpaceInfo
from lcm.grids import ContinuousGrid
from lcm.interfaces import SpaceInfo
from lcm.typing import MapCoordinatesOptions


Expand Down Expand Up @@ -120,8 +120,7 @@ def get_function_representation(
for var, grid_spec in space_info.interpolation_info.items():
funcs[f"__{var}_coord__"] = _get_coordinate_finder(
in_name=input_prefix + var,
grid_type=grid_spec.kind,
grid_info=grid_spec.info,
grid=grid_spec, # type: ignore[arg-type]
)

# ==============================================================================
Expand Down Expand Up @@ -199,8 +198,7 @@ def lookup_wrapper(*args, **kwargs):

def _get_coordinate_finder(
in_name: str,
grid_type: ContinuousGridType,
grid_info: ContinuousGridInfo,
grid: ContinuousGrid,
) -> Callable[..., Array]:
"""Create a function that translates a value into coordinates on a grid.
Expand All @@ -210,22 +208,19 @@ def _get_coordinate_finder(
Args:
in_name: Name via which the value to be translated into coordinates will be
passed into the resulting function.
grid_type: Type of the grid, e.g. "linspace" or "logspace". The type of grid
must be implemented in lcm.grids.
grid_info: Information on how to build the grid, e.g. start, stop, and n_points.
grid: The continuous grid on which the value is to be translated into
coordinates.
Returns:
callable: A callable with keyword-only argument [in_name] that translates a
value into coordinates on a grid.
"""
raw_func = getattr(grids_module, f"get_{grid_type}_coordinate")
partialled_func = partial(raw_func, **grid_info._asdict())

@with_signature(args=[in_name])
def find_coordinate(*args, **kwargs):
kwargs = all_as_kwargs(args, kwargs, arg_names=[in_name])
return partialled_func(kwargs[in_name])
return grid.get_coordinate(kwargs[in_name])

return find_coordinate

Expand Down
51 changes: 31 additions & 20 deletions src/lcm/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,15 @@

from abc import ABC, abstractmethod
from collections.abc import Collection
from dataclasses import dataclass, field
from dataclasses import dataclass

import jax.numpy as jnp
from jax import Array

from lcm import grid_helpers
from lcm.exceptions import GridInitializationError, format_messages
from lcm.grid_helpers import linspace, logspace
from lcm.interfaces import ContinuousGridInfo
from lcm.typing import ContinuousGridType

build_grid_mapping = {
"linspace": linspace,
"logspace": logspace,
}
from lcm.typing import Scalar


class Grid(ABC):
Expand Down Expand Up @@ -48,16 +44,15 @@ def __post_init__(self) -> None:
msg = format_messages(errors)
raise GridInitializationError(msg)

def to_jax(self) -> jnp.ndarray:
def to_jax(self) -> Array:
"""Convert the grid to a Jax array."""
return jnp.array(list(self.options))


@dataclass(frozen=True)
class ContinuousGrid(Grid):
@dataclass(frozen=True, kw_only=True)
class ContinuousGrid(Grid, ABC):
"""LCM Continuous Grid base class."""

kind: ContinuousGridType = field(init=False, default=None) # type: ignore[arg-type]
start: int | float
stop: int | float
n_points: int
Expand All @@ -81,13 +76,13 @@ def info(self) -> ContinuousGridInfo:
n_points=self.n_points,
)

def to_jax(self) -> jnp.ndarray:
@abstractmethod
def to_jax(self) -> Array:
"""Convert the grid to a Jax array."""
return build_grid_mapping[self.kind](
start=self.start,
stop=self.stop,
n_points=self.n_points,
)

@abstractmethod
def get_coordinate(self, value: Scalar) -> Scalar:
"""Get the generalized coordinate of a value in the grid."""


class LinspaceGrid(ContinuousGrid):
Expand All @@ -104,7 +99,15 @@ class LinspaceGrid(ContinuousGrid):
"""

kind: ContinuousGridType = "linspace"
def to_jax(self) -> Array:
"""Convert the grid to a Jax array."""
return grid_helpers.linspace(self.start, self.stop, self.n_points)

def get_coordinate(self, value: Scalar) -> Scalar:
"""Get the generalized coordinate of a value in the grid."""
return grid_helpers.get_linspace_coordinate(
value, self.start, self.stop, self.n_points
)


class LogspaceGrid(ContinuousGrid):
Expand All @@ -121,7 +124,15 @@ class LogspaceGrid(ContinuousGrid):
"""

kind: ContinuousGridType = "logspace"
def to_jax(self) -> Array:
"""Convert the grid to a Jax array."""
return grid_helpers.logspace(self.start, self.stop, self.n_points)

def get_coordinate(self, value: Scalar) -> Scalar:
"""Get the generalized coordinate of a value in the grid."""
return grid_helpers.get_logspace_coordinate(
value, self.start, self.stop, self.n_points
)


# ======================================================================================
Expand Down
7 changes: 2 additions & 5 deletions tests/test_function_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
)
from lcm.grid_helpers import linspace
from lcm.interfaces import (
ContinuousGridInfo,
IndexerInfo,
SpaceInfo,
)
Expand Down Expand Up @@ -274,8 +273,7 @@ def test_get_lookup_function():
def test_get_coordinate_finder():
find_coordinate = _get_coordinate_finder(
in_name="wealth",
grid_type="linspace",
grid_info=ContinuousGridInfo(start=0, stop=10, n_points=21),
grid=LinspaceGrid(start=0, stop=10, n_points=21),
)

calculated = find_coordinate(wealth=5.75)
Expand Down Expand Up @@ -356,8 +354,7 @@ def test_get_lookup_function_illustrative():
def test_get_coordinate_finder_illustrative():
find_coordinate = _get_coordinate_finder(
in_name="a",
grid_type="linspace",
grid_info=ContinuousGridInfo(start=0, stop=1, n_points=3),
grid=LinspaceGrid(start=0, stop=1, n_points=3),
)

assert find_coordinate(a=0) == 0
Expand Down
27 changes: 13 additions & 14 deletions tests/test_process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from numpy.testing import assert_array_equal
from pandas.testing import assert_frame_equal

import lcm.grid_helpers as grids_module
from lcm import DiscreteGrid, LinspaceGrid
from lcm import DiscreteGrid, LinspaceGrid, grid_helpers
from lcm.mark import StochasticInfo
from lcm.process_model import (
_get_function_info,
Expand Down Expand Up @@ -155,13 +154,13 @@ def test_process_model_iskhakov_et_al_2017():
assert model.gridspecs["lagged_retirement"] == DiscreteGrid([0, 1])

# Grids
func = getattr(grids_module, model.gridspecs["consumption"].kind)
asserted = func(**model.gridspecs["consumption"].info._asdict())
assert (asserted == model.grids["consumption"]).all()
expected = grid_helpers.linspace(
**model_config.choices["consumption"].info._asdict()
)
assert_array_equal(model.grids["consumption"], expected)

func = getattr(grids_module, model.gridspecs["wealth"].kind)
asserted = func(**model.gridspecs["wealth"].info._asdict())
assert (asserted == model.grids["wealth"]).all()
expected = grid_helpers.linspace(**model_config.states["wealth"].info._asdict())
assert_array_equal(model.grids["wealth"], expected)

assert (model.grids["retirement"] == jnp.array([0, 1])).all()
assert (model.grids["lagged_retirement"] == jnp.array([0, 1])).all()
Expand Down Expand Up @@ -214,13 +213,13 @@ def test_process_model():
assert model.gridspecs["retirement"] == DiscreteGrid([0, 1])

# Grids
func = getattr(grids_module, model.gridspecs["consumption"].kind)
asserted = func(**model.gridspecs["consumption"].info._asdict())
assert (asserted == model.grids["consumption"]).all()
expected = grid_helpers.linspace(
**model_config.choices["consumption"].info._asdict()
)
assert_array_equal(model.grids["consumption"], expected)

func = getattr(grids_module, model.gridspecs["wealth"].kind)
asserted = func(**model.gridspecs["wealth"].info._asdict())
assert (asserted == model.grids["wealth"]).all()
expected = grid_helpers.linspace(**model_config.states["wealth"].info._asdict())
assert_array_equal(model.grids["wealth"], expected)

assert (model.grids["retirement"] == jnp.array([0, 1])).all()

Expand Down

0 comments on commit 58db89d

Please sign in to comment.