Skip to content

Commit

Permalink
python3Packages.jax: add operations to cuda test
Browse files Browse the repository at this point in the history
  • Loading branch information
stephen-huan committed Feb 2, 2025
1 parent 17b7964 commit 6981825
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions pkgs/development/python-modules/jax/test-cuda.nix
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@ pkgs.writers.writePython3Bin "jax-test-cuda"
}
''
import jax
import jax.numpy as jnp
from jax import random
from jax.experimental import sparse
assert jax.devices()[0].platform == "gpu"
assert jax.devices()[0].platform == "gpu" # libcuda.so
rng = random.PRNGKey(0)
rng = random.key(0) # libcudart.so, libcudnn.so
x = random.normal(rng, (100, 100))
x @ x
x @ x # libcublas.so
jnp.fft.fft(x) # libcufft.so
jnp.linalg.inv(x) # libcusolver.so
sparse.CSR.fromdense(x) @ x # libcusparse.so
print("success!")
''

0 comments on commit 6981825

Please sign in to comment.