Skip to content

Commit

Permalink
Update ndimage tests and add jit to map_coordinates
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 26, 2024
1 parent 642ba66 commit 1add307
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 11 deletions.
3 changes: 2 additions & 1 deletion src/lcm/ndimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from collections.abc import Sequence

import jax.numpy as jnp
from jax import Array, lax, util
from jax import Array, jit, lax, util


@jit
def map_coordinates(
input: Array,
coordinates: Sequence[Array],
Expand Down
10 changes: 0 additions & 10 deletions tests/test_ndimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,3 @@ def f(step):
# Gradient of f(step) is 2 * step
assert_allclose(jax.grad(f)(0.5), 1.0)
assert_allclose(jax.grad(f)(1.0), 2.0)


def test_extrapolation():
x = jnp.arange(3.0)
c = [jnp.array([-2.0, -1.0, 5.0, 10.0])]

got = lcm_map_coordinates(x, c)
expected = c[0]

assert_array_equal(got, expected)
19 changes: 19 additions & 0 deletions tests/test_ndimage_unit.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,33 @@
import jax.numpy as jnp
import pytest
from numpy.testing import assert_array_equal

from lcm.ndimage import (
_compute_indices_and_weights,
_multiply_all,
_round_half_away_from_zero,
_sum_all,
map_coordinates,
)


def test_map_coordinates_wrong_input_dimensions():
values = jnp.arange(2) # ndim = 1
coordinates = [jnp.array([0]), jnp.array([1])] # len = 2
with pytest.raises(ValueError, match="coordinates must be a sequence of length"):
map_coordinates(values, coordinates)


def test_map_coordinates_extrapolation():
x = jnp.arange(3.0)
c = [jnp.array([-2.0, -1.0, 5.0, 10.0])]

got = map_coordinates(x, c)
expected = c[0]

assert_array_equal(got, expected)


def test_nonempty_sum():
a = jnp.arange(3)

Expand Down

0 comments on commit 1add307

Please sign in to comment.