Skip to content

Commit

Permalink
test: add MVP jax tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andnp committed Feb 10, 2025
1 parent 3e4a4e7 commit cceed53
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions tests/integration/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from functools import partial
import jax

Check failure on line 2 in tests/integration/test_jax.py

View workflow job for this annotation

GitHub Actions / check

Import "jax" could not be resolved (reportMissingImports)
import jax.experimental

Check failure on line 3 in tests/integration/test_jax.py

View workflow job for this annotation

GitHub Actions / check

Import "jax.experimental" could not be resolved (reportMissingImports)
from ml_instrumentation.Collector import Collector
from ml_instrumentation.Writer import SqlPoint

def test_collector_jax_jit(basic_collector: Collector):
basic_collector.set_experiment_id(0)
basic_collector.next_frame()

@partial(jax.jit, static_argnums=(1,))
def f(x: jax.Array, collector: Collector):
y = 2 * x
y = y.dot(y)
collector.collect_jax('m1', y)
return y

x = jax.numpy.array([0, 1, 2, 3])
y = f(x, basic_collector)

assert y == 56
assert basic_collector.get('m1', 0) == [
SqlPoint(frame=0, id=0, measurement=56)
]

def test_collector_jax_grad(basic_collector: Collector):
basic_collector.set_experiment_id(0)
basic_collector.next_frame()

@partial(jax.jit, static_argnums=(1,))
@jax.grad
def f(x: jax.Array, collector: Collector):
y = 2 * x
y = y.dot(y)
collector.collect_jax('m1', y)
return y

x = jax.numpy.array([0, 1, 2, 3.])
f(x, basic_collector)

assert basic_collector.get('m1', 0) == [
SqlPoint(frame=0, id=0, measurement=56)
]

def test_collector_jax_vmap(basic_collector: Collector):
basic_collector.set_experiment_id(0)
basic_collector.next_frame()

@partial(jax.jit, static_argnums=(1,))
@partial(jax.vmap, in_axes=(0, None), out_axes=0)
def f(x: jax.Array, collector: Collector):
y = 2 * x
y = y.dot(y)
collector.collect_jax('m1', y)
return y

x = jax.numpy.array([
[0, 1, 2, 3.],
[1, 1, 2, 3.],
])
f(x, basic_collector)

assert basic_collector.get('m1', 0) == [
SqlPoint(frame=0, id=0, measurement=56),
SqlPoint(frame=0, id=0, measurement=60),
]

0 comments on commit cceed53

Please sign in to comment.