From 1add30778dbb43ce008b3096fd3191338c2b8047 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 26 Sep 2024 14:14:28 +0200 Subject: [PATCH] Update ndimage tests and add jit to map_coordinates --- src/lcm/ndimage.py | 3 ++- tests/test_ndimage.py | 10 ---------- tests/test_ndimage_unit.py | 19 +++++++++++++++++++ 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/src/lcm/ndimage.py b/src/lcm/ndimage.py index cb19c50..3dbdd1a 100644 --- a/src/lcm/ndimage.py +++ b/src/lcm/ndimage.py @@ -20,9 +20,10 @@ from collections.abc import Sequence import jax.numpy as jnp -from jax import Array, lax, util +from jax import Array, jit, lax, util +@jit def map_coordinates( input: Array, coordinates: Sequence[Array], diff --git a/tests/test_ndimage.py b/tests/test_ndimage.py index 003efa8..1c0bf50 100644 --- a/tests/test_ndimage.py +++ b/tests/test_ndimage.py @@ -100,13 +100,3 @@ def f(step): # Gradient of f(step) is 2 * step assert_allclose(jax.grad(f)(0.5), 1.0) assert_allclose(jax.grad(f)(1.0), 2.0) - - -def test_extrapolation(): - x = jnp.arange(3.0) - c = [jnp.array([-2.0, -1.0, 5.0, 10.0])] - - got = lcm_map_coordinates(x, c) - expected = c[0] - - assert_array_equal(got, expected) diff --git a/tests/test_ndimage_unit.py b/tests/test_ndimage_unit.py index 5fbafe2..19e2029 100644 --- a/tests/test_ndimage_unit.py +++ b/tests/test_ndimage_unit.py @@ -1,4 +1,5 @@ import jax.numpy as jnp +import pytest from numpy.testing import assert_array_equal from lcm.ndimage import ( @@ -6,9 +7,27 @@ _multiply_all, _round_half_away_from_zero, _sum_all, + map_coordinates, ) +def test_map_coordinates_wrong_input_dimensions(): + values = jnp.arange(2) # ndim = 1 + coordinates = [jnp.array([0]), jnp.array([1])] # len = 2 + with pytest.raises(ValueError, match="coordinates must be a sequence of length"): + map_coordinates(values, coordinates) + + +def test_map_coordinates_extrapolation(): + x = jnp.arange(3.0) + c = [jnp.array([-2.0, -1.0, 5.0, 10.0])] + + got = map_coordinates(x, c) + expected = c[0] + + assert_array_equal(got, expected) + + def test_nonempty_sum(): a = jnp.arange(3)