Skip to content

Commit

Permalink
Update ndimage tests and add module for unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 25, 2024
1 parent 2c65c3a commit cd46591
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 32 deletions.
7 changes: 0 additions & 7 deletions src/lcm/ndimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,6 @@ def map_coordinates(
f"{len(coordinates)} != {input.ndim}"
)

return _map_coordinates(input, coordinates)


def _map_coordinates(
input: Array,
coordinates: Sequence[Array],
) -> Array:
valid_1d_interpolations = []
for coordinate, size in util.safe_zip(coordinates, input.shape):
interp_nodes = _linear_indices_and_weights(coordinate, input_size=size)
Expand Down
36 changes: 11 additions & 25 deletions tests/test_ndimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
scipy_map_coordinates = partial(scipy.ndimage.map_coordinates, order=1, cval=0)
lcm_map_coordinates = lcm.ndimage.map_coordinates

JAX_IMPLEMENTATIONS = [jax_map_coordinates, lcm_map_coordinates]
JAX_BASED_IMPLEMENTATIONS = [jax_map_coordinates, lcm_map_coordinates]


TEST_SHAPES = [
Expand All @@ -51,14 +51,14 @@ def _make_test_data(shape, coordinates_shape, dtype):
return x, c


@pytest.mark.parametrize("map_coordinates", JAX_IMPLEMENTATIONS)
@pytest.mark.parametrize("map_coordinates", JAX_BASED_IMPLEMENTATIONS)
@pytest.mark.parametrize("shape", TEST_SHAPES)
@pytest.mark.parametrize("coordinates_shape", TEST_COORDINATES_SHAPES)
@pytest.mark.parametrize("dtype", [np.int64, np.float64])
def test_map_coordinates_against_scipy(
map_coordinates, shape, coordinates_shape, dtype
):
"""Test that all libraries implement same behavior with integer input."""
"""Test that JAX and LCM implementations behave as scipy."""
x, c = _make_test_data(shape, coordinates_shape, dtype=dtype)

x_jax = jnp.asarray(x)
Expand All @@ -70,25 +70,11 @@ def test_map_coordinates_against_scipy(
assert_array_almost_equal(got, expected, decimal=14)


@pytest.mark.parametrize("map_coordinates", JAX_IMPLEMENTATIONS)
def test_map_coordinates_round_half_integer_input(map_coordinates):
"""Test that all libraries implement same rounding behavior with integer input."""
x = np.arange(-5, 5, dtype=np.int64)
c = np.array([[0.5, 1.5, 2.5, 6.5, 8.5]])

x_jax = jnp.asarray(x)
c_jax = [jnp.asarray(c_i) for c_i in c]

expected = scipy_map_coordinates(x, c)
got = map_coordinates(x_jax, c_jax)

assert_array_equal(got, expected)


@pytest.mark.parametrize("map_coordinates", JAX_IMPLEMENTATIONS)
def test_map_coordinates_round_half_float_input(map_coordinates):
"""Test that all libraries implement same rounding behavior with float input."""
x = np.arange(-5, 5, dtype=np.float64)
@pytest.mark.parametrize("map_coordinates", JAX_BASED_IMPLEMENTATIONS)
@pytest.mark.parametrize("dtype", [np.int64, np.float64])
def test_map_coordinates_round_half_against_scipy(map_coordinates, dtype):
"""Test that JAX and LCM implementations round as scipy."""
x = np.arange(-5, 5, dtype=dtype)
c = np.array([[0.5, 1.5, 2.5, 6.5, 8.5]])

x_jax = jnp.asarray(x)
Expand All @@ -100,9 +86,9 @@ def test_map_coordinates_round_half_float_input(map_coordinates):
assert_array_equal(got, expected)


@pytest.mark.parametrize("map_coordinates", JAX_IMPLEMENTATIONS)
@pytest.mark.parametrize("map_coordinates", JAX_BASED_IMPLEMENTATIONS)
def test_gradients(map_coordinates):
"""Test that JAX based implementations exhibit same gradient behavior."""
"""Test that JAX and LCM implementations exhibit same gradient behavior."""
x = jnp.arange(9.0)
border = 3 # square root of 9, as we are considering a parabola on x.

Expand All @@ -120,7 +106,7 @@ def test_extrapolation():
x = jnp.arange(3.0)
c = [jnp.array([-2.0, -1.0, 5.0, 10.0])]

got = lcm.ndimage.map_coordinates(x, c)
got = lcm_map_coordinates(x, c)
expected = c[0]

assert_array_equal(got, expected)
68 changes: 68 additions & 0 deletions tests/test_ndimage_unit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import jax.numpy as jnp
from numpy.testing import assert_array_equal

from lcm.ndimage import (
_linear_indices_and_weights,
_nonempty_prod,
_nonempty_sum,
_round_half_away_from_zero,
)


def test_nonempty_sum():
a = jnp.arange(3)

expected = a + a + a
got = _nonempty_sum([a, a, a])

assert_array_equal(got, expected)


def test_nonempty_prod():
a = jnp.arange(3)

expected = a * a * a
got = _nonempty_prod([a, a, a])

assert_array_equal(got, expected)


def test_round_half_away_from_zero_integer():
a = jnp.array([1, 2], dtype=jnp.int32)
assert_array_equal(_round_half_away_from_zero(a), a)


def test_round_half_away_from_zero_float():
a = jnp.array([0.5, 1.5], dtype=jnp.float32)

expected = jnp.array([1, 2], dtype=jnp.int32)
got = _round_half_away_from_zero(a)

assert_array_equal(got, expected)


def test_linear_indices_and_weights_inside_domain():
"""Test that the indices and weights are correct for a points inside the domain."""
coordinates = jnp.array([0, 0.5, 1])

(idx_low, weight_low), (idx_high, weight_high) = _linear_indices_and_weights(
coordinates, input_size=2
)

assert_array_equal(idx_low, jnp.array([0, 0, 0], dtype=jnp.int32))
assert_array_equal(weight_low, jnp.array([1, 0.5, 0], dtype=jnp.float32))
assert_array_equal(idx_high, jnp.array([1, 1, 1], dtype=jnp.int32))
assert_array_equal(weight_high, jnp.array([0, 0.5, 1], dtype=jnp.float32))


def test_linear_indices_and_weights_outside_domain():
coordinates = jnp.array([-1, 2])

(idx_low, weight_low), (idx_high, weight_high) = _linear_indices_and_weights(
coordinates, input_size=2
)

assert_array_equal(idx_low, jnp.array([0, 0], dtype=jnp.int32))
assert_array_equal(weight_low, jnp.array([2, -1], dtype=jnp.float32))
assert_array_equal(idx_high, jnp.array([1, 1], dtype=jnp.int32))
assert_array_equal(weight_high, jnp.array([-1, 2], dtype=jnp.float32))

0 comments on commit cd46591

Please sign in to comment.