Skip to content

Commit

Permalink
Rewrite ndimage.py slightly and add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 25, 2024
1 parent cd46591 commit 642ba66
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 25 deletions.
39 changes: 21 additions & 18 deletions src/lcm/ndimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,40 +50,43 @@ def map_coordinates(
f"{len(coordinates)} != {input.ndim}"
)

valid_1d_interpolations = []
for coordinate, size in util.safe_zip(coordinates, input.shape):
interp_nodes = _linear_indices_and_weights(coordinate, input_size=size)
valid_1d_interpolations.append(interp_nodes)

outputs = []
for items in itertools.product(*valid_1d_interpolations):
indices, weights = util.unzip2(items)
interpolation_data = [
_compute_indices_and_weights(coordinate, size)
for coordinate, size in util.safe_zip(coordinates, input.shape)
]

interpolation_values = []
for indices_and_weights in itertools.product(*interpolation_data):
indices, weights = util.unzip2(indices_and_weights)
contribution = input[indices]
outputs.append(_nonempty_prod(weights) * contribution)
weighted_value = _multiply_all(weights) * contribution
interpolation_values.append(weighted_value)

result = _nonempty_sum(outputs)
result = _sum_all(interpolation_values)

if jnp.issubdtype(input.dtype, jnp.integer):
result = _round_half_away_from_zero(result)

return result.astype(input.dtype)


def _linear_indices_and_weights(
def _compute_indices_and_weights(
coordinate: Array, input_size: int
) -> tuple[tuple[Array, Array], tuple[Array, Array]]:
lower = jnp.clip(jnp.floor(coordinate), min=0, max=input_size - 2)
upper_weight = coordinate - lower
) -> list[tuple[Array, Array]]:
"""Compute indices and weights for linear interpolation."""
lower_index = jnp.clip(jnp.floor(coordinate), 0, input_size - 2).astype(jnp.int32)
upper_weight = coordinate - lower_index
lower_weight = 1 - upper_weight
index = lower.astype(jnp.int32)
return (index, lower_weight), (index + 1, upper_weight)
return [(lower_index, lower_weight), (lower_index + 1, upper_weight)]


def _nonempty_prod(arrs: Sequence[Array]) -> Array:
def _multiply_all(arrs: Sequence[Array]) -> Array:
"""Multiply all arrays in the sequence."""
return functools.reduce(operator.mul, arrs)


def _nonempty_sum(arrs: Sequence[Array]) -> Array:
def _sum_all(arrs: Sequence[Array]) -> Array:
"""Sum all arrays in the sequence."""
return functools.reduce(operator.add, arrs)


Expand Down
14 changes: 7 additions & 7 deletions tests/test_ndimage_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
from numpy.testing import assert_array_equal

from lcm.ndimage import (
_linear_indices_and_weights,
_nonempty_prod,
_nonempty_sum,
_compute_indices_and_weights,
_multiply_all,
_round_half_away_from_zero,
_sum_all,
)


def test_nonempty_sum():
a = jnp.arange(3)

expected = a + a + a
got = _nonempty_sum([a, a, a])
got = _sum_all([a, a, a])

assert_array_equal(got, expected)

Expand All @@ -22,7 +22,7 @@ def test_nonempty_prod():
a = jnp.arange(3)

expected = a * a * a
got = _nonempty_prod([a, a, a])
got = _multiply_all([a, a, a])

assert_array_equal(got, expected)

Expand All @@ -45,7 +45,7 @@ def test_linear_indices_and_weights_inside_domain():
"""Test that the indices and weights are correct for a points inside the domain."""
coordinates = jnp.array([0, 0.5, 1])

(idx_low, weight_low), (idx_high, weight_high) = _linear_indices_and_weights(
(idx_low, weight_low), (idx_high, weight_high) = _compute_indices_and_weights(
coordinates, input_size=2
)

Expand All @@ -58,7 +58,7 @@ def test_linear_indices_and_weights_inside_domain():
def test_linear_indices_and_weights_outside_domain():
coordinates = jnp.array([-1, 2])

(idx_low, weight_low), (idx_high, weight_high) = _linear_indices_and_weights(
(idx_low, weight_low), (idx_high, weight_high) = _compute_indices_and_weights(
coordinates, input_size=2
)

Expand Down

0 comments on commit 642ba66

Please sign in to comment.