Skip to content

Commit

Permalink
Implement linear extrapolation (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens authored Sep 26, 2024
1 parent 8643b0e commit 7e8e3f8
Show file tree
Hide file tree
Showing 19 changed files with 320 additions and 121 deletions.
35 changes: 15 additions & 20 deletions explanations/function_representation.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ extend-ignore = [
"explanations/*" = ["INP001", "B018", "T201", "E402", "PD008"]
"scripts/*" = ["INP001", "D101", "RET503"]
"**/*.ipynb" = ["FBT003", "E402", "D101"]
"src/lcm/ndimage.py" = ["A002"] # Argument `input` is shadowing a Python builtin

[tool.ruff.lint.pydocstyle]
convention = "google"
Expand Down
6 changes: 0 additions & 6 deletions src/lcm/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
get_utility_and_feasibility_function,
)
from lcm.next_state import get_next_state_function
from lcm.options import DefaultMapCoordinatesOptions
from lcm.simulate import simulate
from lcm.solve_brute import solve
from lcm.state_space import create_state_choice_space
Expand All @@ -29,7 +28,6 @@ def get_lcm_function(
*,
debug_mode: bool = True,
jit: bool = True,
interpolation_options: dict | None = None,
) -> tuple[Callable, ParamsDict]:
"""Entry point for users to get high level functions generated by lcm.
Expand All @@ -51,8 +49,6 @@ def get_lcm_function(
"solve_and_simulate" are supported.
debug_mode: Whether to log debug messages.
jit: Whether to jit the returned function.
interpolation_options: Dictionary of keyword arguments for interpolation
via map_coordinates. If None, the default options are used.
Returns:
- A function that takes params (and possibly other arguments, such as initial
Expand All @@ -68,7 +64,6 @@ def get_lcm_function(

_mod = process_model(model)
last_period = _mod.n_periods - 1
interpolation_options = DefaultMapCoordinatesOptions | (interpolation_options or {})

logger = get_logger(debug_mode)

Expand Down Expand Up @@ -133,7 +128,6 @@ def get_lcm_function(
model=_mod,
space_info=space_infos[period],
name_of_values_on_grid="vf_arr",
interpolation_options=interpolation_options,
period=period,
is_last_period=is_last_period,
)
Expand Down
16 changes: 2 additions & 14 deletions src/lcm/function_representation.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
from collections.abc import Callable
from functools import partial

import jax.numpy as jnp
from dags import concatenate_functions
from dags.signature import with_signature
from jax import Array
from jax.scipy.ndimage import map_coordinates

from lcm.functools import all_as_kwargs
from lcm.grids import ContinuousGrid
from lcm.interfaces import SpaceInfo
from lcm.typing import MapCoordinatesOptions
from lcm.ndimage import map_coordinates


def get_function_representation(
space_info: SpaceInfo,
name_of_values_on_grid: str,
*,
interpolation_options: MapCoordinatesOptions,
input_prefix: str = "",
) -> Callable[..., Array]:
"""Create a function representation of pre-calculated values on a grid.
Expand Down Expand Up @@ -61,9 +58,6 @@ def get_function_representation(
into the resulting function. In the value function case, this could be
'vf_arr', in which case, one would partial in 'vf_arr' into the
representation.
interpolation_options: Dictionary of interpolation options that will be passed
to jax.scipy.ndimage.map_coordinates. If None, DefaultMapCoordinatesOptions
will be used.
input_prefix: Prefix that will be added to all argument names of the resulting
function, except for the helpers arguments such as indexers or value arrays.
Default is the empty string. The prefix needs to contain the separator. E.g.
Expand Down Expand Up @@ -134,7 +128,6 @@ def get_function_representation(
funcs["__fval__"] = _get_interpolator(
name_of_values_on_grid="__interpolation_data__",
axis_names=_interpolation_axes,
map_coordinates_options=interpolation_options,
)

return concatenate_functions(
Expand Down Expand Up @@ -228,7 +221,6 @@ def find_coordinate(*args, **kwargs):
def _get_interpolator(
name_of_values_on_grid: str,
axis_names: list[str],
map_coordinates_options: MapCoordinatesOptions,
) -> Callable[..., Array]:
"""Create a function interpolator via named axes.
Expand All @@ -237,22 +229,18 @@ def _get_interpolator(
values, that have been evaluated on a grid, will be passed into the
resulting function.
axis_names: Names of the axes in the data array.
map_coordinates_options: Dictionary of interpolation options that will be passed
to jax.scipy.ndimage.map_coordinates.
Returns:
callable: A callable that interpolates a function via named axes.
"""
partialled_map_coordinates = partial(map_coordinates, **map_coordinates_options)

arg_names = [name_of_values_on_grid, *axis_names]

@with_signature(args=arg_names)
def interpolate(*args, **kwargs):
kwargs = all_as_kwargs(args, kwargs, arg_names=arg_names)
coordinates = jnp.array([kwargs[var] for var in axis_names])
return partialled_map_coordinates(
return map_coordinates(
input=kwargs[name_of_values_on_grid],
coordinates=coordinates,
)
Expand Down
2 changes: 0 additions & 2 deletions src/lcm/model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def get_utility_and_feasibility_function(
model: InternalModel,
space_info,
name_of_values_on_grid,
interpolation_options,
period,
is_last_period,
):
Expand All @@ -45,7 +44,6 @@ def get_utility_and_feasibility_function(
scalar_value_function = get_function_representation(
space_info=space_info,
name_of_values_on_grid=name_of_values_on_grid,
interpolation_options=interpolation_options,
input_prefix="next_",
)

Expand Down
95 changes: 95 additions & 0 deletions src/lcm/ndimage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modifications made by Tim Mensinger, 2024

import functools
import itertools
import operator
from collections.abc import Sequence

import jax.numpy as jnp
from jax import Array, jit, lax, util


@jit
def map_coordinates(
input: Array,
coordinates: Sequence[Array],
) -> Array:
"""Map the input array to new coordinates using linear interpolation.
Modified from JAX implementation of :func:`scipy.ndimage.map_coordinates`.
Given an input array and a set of coordinates, this function returns the
interpolated values of the input array at those coordinates. For coordinates outside
the input array, linear extrapolation is used.
Args:
input: N-dimensional input array from which values are interpolated.
coordinates: length-N sequence of arrays specifying the coordinates
at which to evaluate the interpolated values
Returns:
The interpolated (extrapolated) values at the specified coordinates.
"""
if len(coordinates) != input.ndim:
raise ValueError(
"coordinates must be a sequence of length input.ndim, but "
f"{len(coordinates)} != {input.ndim}"
)

interpolation_data = [
_compute_indices_and_weights(coordinate, size)
for coordinate, size in util.safe_zip(coordinates, input.shape)
]

interpolation_values = []
for indices_and_weights in itertools.product(*interpolation_data):
indices, weights = util.unzip2(indices_and_weights)
contribution = input[indices]
weighted_value = _multiply_all(weights) * contribution
interpolation_values.append(weighted_value)

result = _sum_all(interpolation_values)

if jnp.issubdtype(input.dtype, jnp.integer):
result = _round_half_away_from_zero(result)

return result.astype(input.dtype)


def _compute_indices_and_weights(
coordinate: Array, input_size: int
) -> list[tuple[Array, Array]]:
"""Compute indices and weights for linear interpolation."""
lower_index = jnp.clip(jnp.floor(coordinate), 0, input_size - 2).astype(jnp.int32)
upper_weight = coordinate - lower_index
lower_weight = 1 - upper_weight
return [(lower_index, lower_weight), (lower_index + 1, upper_weight)]


def _multiply_all(arrs: Sequence[Array]) -> Array:
"""Multiply all arrays in the sequence."""
return functools.reduce(operator.mul, arrs)


def _sum_all(arrs: Sequence[Array]) -> Array:
"""Sum all arrays in the sequence."""
return functools.reduce(operator.add, arrs)


def _round_half_away_from_zero(a: Array) -> Array:
return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a)
7 changes: 0 additions & 7 deletions src/lcm/options.py

This file was deleted.

30 changes: 1 addition & 29 deletions src/lcm/typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Literal, TypedDict
from typing import Any, TypedDict

from jax import Array

Expand Down Expand Up @@ -31,31 +31,3 @@ class SegmentInfo(TypedDict):

segment_ids: Array
num_segments: int


class MapCoordinatesOptions(TypedDict):
"""Options passed to `jax.scipy.ndimage.map_coordinates`.
From the JAX documentation (as of 2024-06-16):
- "order": The order of interpolation. JAX supports the following:
- 0: Nearest-neighbor
- 1: Linear
- "mode": Points outside the boundaries of the input are filled according to the
given mode. JAX supports one of ('constant', 'nearest', 'mirror', 'wrap',
'reflect'). Note the 'wrap' mode in JAX behaves as 'grid-wrap' mode in SciPy, and
'constant' mode in JAX behaves as 'grid-constant' mode in SciPy. This discrepancy
was caused by a former bug in those modes in SciPy (scipy/scipy#2640), which was
first fixed in JAX by changing the behavior of the existing modes, and later on
fixed in SciPy, by adding modes with new names, rather than fixing the existing
ones, for backwards compatibility reasons.
- "cval": Value used for points outside the boundaries of the input if
mode='constant'.
"""

order: Literal[0, 1]
mode: Literal["constant", "nearest", "mirror", "wrap", "reflect"]
cval: Scalar
Binary file modified tests/data/regression_tests/simulation.pkl
Binary file not shown.
Binary file modified tests/data/regression_tests/solution.pkl
Binary file not shown.
4 changes: 0 additions & 4 deletions tests/test_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ def test_create_compute_conditional_continuation_value():
model=model,
space_info=space_info,
name_of_values_on_grid="vf_arr",
interpolation_options={},
period=model.n_periods - 1,
is_last_period=True,
)
Expand Down Expand Up @@ -242,7 +241,6 @@ def test_create_compute_conditional_continuation_value_with_discrete_model():
model=model,
space_info=space_info,
name_of_values_on_grid="vf_arr",
interpolation_options={},
period=model.n_periods - 1,
is_last_period=True,
)
Expand Down Expand Up @@ -296,7 +294,6 @@ def test_create_compute_conditional_continuation_policy():
model=model,
space_info=space_info,
name_of_values_on_grid="vf_arr",
interpolation_options={},
period=model.n_periods - 1,
is_last_period=True,
)
Expand Down Expand Up @@ -346,7 +343,6 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model():
model=model,
space_info=space_info,
name_of_values_on_grid="vf_arr",
interpolation_options={},
period=model.n_periods - 1,
is_last_period=True,
)
Expand Down
8 changes: 0 additions & 8 deletions tests/test_function_representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
IndexerInfo,
SpaceInfo,
)
from lcm.options import DefaultMapCoordinatesOptions


def test_function_evaluator_with_one_continuous_variable():
Expand All @@ -40,7 +39,6 @@ def test_function_evaluator_with_one_continuous_variable():
space_info=space_info,
name_of_values_on_grid="vf_arr",
input_prefix="next_",
interpolation_options=DefaultMapCoordinatesOptions,
)

# partial the function values into the evaluator
Expand All @@ -67,7 +65,6 @@ def test_function_evaluator_with_one_discrete_variable():
space_info=space_info,
name_of_values_on_grid="vf_arr",
input_prefix="next_",
interpolation_options=DefaultMapCoordinatesOptions,
)

# partial the function values into the evaluator
Expand Down Expand Up @@ -139,7 +136,6 @@ def test_function_evaluator():
evaluator = get_function_representation(
space_info=space_info,
name_of_values_on_grid="vf_arr",
interpolation_options=DefaultMapCoordinatesOptions,
)

# test the evaluator
Expand Down Expand Up @@ -217,7 +213,6 @@ def test_function_evaluator_longer_indexer():
evaluator = get_function_representation(
space_info=space_info,
name_of_values_on_grid="vf_arr",
interpolation_options=DefaultMapCoordinatesOptions,
)

# test the evaluator
Expand Down Expand Up @@ -282,7 +277,6 @@ def test_get_interpolator():
interpolate = _get_interpolator(
name_of_values_on_grid="vf",
axis_names=["wealth", "working"],
map_coordinates_options=DefaultMapCoordinatesOptions,
)

def _utility(wealth, working):
Expand Down Expand Up @@ -325,7 +319,6 @@ def test_get_function_evaluator_illustrative():
space_info=space_info,
name_of_values_on_grid="values_name",
input_prefix="prefix_",
interpolation_options=DefaultMapCoordinatesOptions,
)

# partial the function values into the evaluator
Expand Down Expand Up @@ -364,7 +357,6 @@ def test_get_interpolator_illustrative():
interpolate = _get_interpolator(
name_of_values_on_grid="test_name",
axis_names=["a", "b"],
map_coordinates_options=DefaultMapCoordinatesOptions,
)

def f(a, b):
Expand Down
Loading

0 comments on commit 7e8e3f8

Please sign in to comment.