
python3Packages.jax: add missing cuda libraries #375186

Open · wants to merge 4 commits into master
Conversation

stephen-huan
Member

Things done

It seems #369920 accidentally deleted a number of CUDA libraries (this breaks jax.numpy.fft.fft, for example), so I added the ones from upstream. A minimal working example is below.

import jax.numpy as jnp
from jax import random

if __name__ == "__main__":
    rng = random.key(0)
    rng, subkey = random.split(rng)

    n = 1 << 10
    x = random.normal(subkey, (n, n))
    print(jnp.fft.fft(x))  # libcufft
    print(jnp.linalg.inv(x))  # libcusolver

In addition, I think I'm running into #164176 (comment), as jax wants ptxas for an XLA custom call. Apologies for the lack of a minimal working example; the FFI interface is relatively involved. I'm testing on the example in jax-triton using this derivation for jax-triton (which doesn't involve any C++ because it can leverage code from jaxlib). The full error is

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CustomCall failed: Couldn't find a suitable version of ptxas. The following locations were considered: bin/ptxas, /nix/store/r1dw8c14c4mvw6b44y0hmvxajap7s6aq-python3-3.12.8-env/bin/ptxas, /nix/store/c9m6yd8fg1flz2j5r4bif1ib5j20a0cy-python3-3.12.8/bin/ptxas, /nix/store/0sgkksgr178brspwfgin207yidgpfw93-python3.12-wheel-0.45.1/bin/ptxas, /nix/store/n6ncnjqqxzhscvjw96jvg0z7r1qwxw4c-python3.12-cmake-3.30.5/bin/ptxas, /nix/store/l22xmzagzf4xbb4jb25sn3aa3wm49x42-python3.12-ninja-1.12.1/bin/ptxas, /nix/store/ll40w9mfpybkn17vs148wz47226i7bbh-python3.12-lit-18.1.8/bin/ptxas, /nix/store/gdp64ps5j34q5n7c0ajb7dbjinnsyv8m-triton-llvm-19.1.0-rc1/bin/ptxas, /nix/store/3lbb9lx1gmis6qm8clq9g9k37i5l8l52-ncurses-6.4.20221231-dev/bin/ptxas, /nix/store/wm1qn5jqrxpcjkc640gq8a90ns5gw3cn-ncurses-6.4.20221231/bin/ptxas, /nix/store/srfxqk119fijwnprgsqvn68ys9kiw0bn-patchelf-0.15.0/bin/ptxas, /nix/store/xcn9p4xxfbvlkpah7pwchpav4ab9d135-gcc-wrapper-14-20241116/bin/ptxas, /nix/store/l89iqc7am6i60y8vk507zwrzxf0wcd3v-gcc-14-20241116/bin/ptxas, /nix/store/1c6bmxrrhm8bd26ai2rjqld2yyjrxhds-glibc-2.40-36-bin/bin/ptxas, /nix/store/4s9rah4cwaxflicsk5cndnknqlk9n4p3-coreutils-9.5/bin/ptxas, /nix/store/srcmmqi8kxjfygd0hyy42c8hv6cws83b-binutils-wrapper-2.43.1/bin/ptxas, /nix/store/j7p46r8v9gcpbxx89pbqlh61zhd33gzv-binutils-2.43.1/bin/ptxas, /nix/store/4s9rah4cwaxflicsk5cndnknqlk9n4p3-coreutils-9.5/bin/ptxas, /nix/store/jqrz1vq5nz4lnv9pqzydj0ir58wbjfy1-findutils-4.10.0/bin/ptxas, /nix/store/00g69vw7c9lycy63h45ximy0wmzqx5y6-diffutils-3.10/bin/ptxas, /nix/store/abm77lnrkrkb58z6xp1qwjcr1xgkcfwm-gnused-4.9/bin/ptxas, /nix/store/aap6cq56amx4mzbyxp2wpgsf1kqjcr1f-gnugrep-3.11/bin/ptxas, /nix/store/a3c47r5z1q2c4rz0kvq8hlilkhx2s718-gawk-5.3.1/bin/ptxas, /nix/store/9cwwj1c9csmc85l2cqzs3h9hbf1vwl6c-gnutar-1.35/bin/ptxas, /nix/store/nvvj6sk0k6px48436drlblf4gafgbvzr-gzip-1.13/bin/ptxas, /nix/store/mglixp03lsp0w986svwdvm7vcy17rdax-bzip2-1.0.8-bin/bin/ptxas, 
/nix/store/fp6cjl1zcmm6mawsnrb5yak1wkz2ma8l-gnumake-4.4.1/bin/ptxas, /nix/store/5mh7kaj2fyv8mk4sfq1brwxgc02884wi-bash-5.2p37/bin/ptxas, /nix/store/5yja5dpk2qw1v5mbfbl2d7klcdfrh90w-patch-2.7.6/bin/ptxas, /nix/store/h18s640fnhhj2qdh5vivcfbxvz377srg-xz-5.6.3-bin/bin/ptxas, /nix/store/c4rj90r2m89rxs64hmm857mipwjhig5d-file-5.46/bin/ptxas, /keep/home/slhuan/programs/nix/pinpkgs/.direnv/bin/ptxas, /home/slhuan/bin/ptxas, /nix/store/c9m6yd8fg1flz2j5r4bif1ib5j20a0cy-python3-3.12.8/bin/ptxas, /nix/store/1drw3gx2q4chryrh6qy62ikhdh0p9392-ranger-1.9.4/bin/ptxas, /nix/store/3cp5gn414sb3v9nzznpi0jq33pj1nr3k-less-668/bin/ptxas, /nix/store/ks6xg51nxksqy772bj3y5682m0xqa689-file-5.46/bin/ptxas, /nix/store/s77ik3wgzf99lzam0m9fgkjjvss420ly-imagemagick-7.1.1-43-dev/bin/ptxas, /nix/store/g26bs63hz87c9s3sg0v42d6d8gjk36qw-curl-8.11.0-dev/bin/ptxas, /nix/store/kcd3nbf8iklwi5djyjyqg4k55cgf0xbp-brotli-1.1.0/bin/ptxas, /nix/store/xdmmmg9cpbp207fgyviwdixc6pwgrm4c-krb5-1.21.3-dev/bin/ptxas, /nix/store/zqkizm9zn19bf9n58xcfxl4ks7yfs8my-krb5-1.21.3-lib/bin/ptxas, /nix/store/l0gvngj52vp9g2fc131rl6289kc3hwnl-krb5-1.21.3/bin/ptxas, /nix/store/3813yqzj1adyilq0g5qdffrj8s506xmq-nghttp2-1.64.0/bin/ptxas, /nix/store/chm94r4s4i1nn6v20idna7h03i4k49i3-libidn2-2.3.7-bin/bin/ptxas, /nix/store/lygl27c44xv73kx1spskcgvzwq7z337c-openssl-3.3.2-bin/bin/ptxas, /nix/store/aqr1xjmsi8xn0kwhjzmc18s251kh7xib-libpsl-0.21.5-bin/bin/ptxas, /nix/store/gxn6if7mzv1fqq2hlsl9nfqvx5ahszg7-zstd-1.5.6-bin/bin/ptxas, /nix/store/s0zynhz13rhg0gg8cy0ga4c0lnv3yqdn-zstd-1.5.6/bin/ptxas, /nix/store/yckhqngmx90bakzhcyrsk7dww1fm352s-curl-8.11.0-bin/bin/ptxas, /nix/store/mglixp03lsp0w986svwdvm7vcy17rdax-bzip2-1.0.8-bin/bin/ptxas, /nix/store/wc97xp20zf23vx6rj65ah6knqpqwn9bn-freetype-2.13.3-dev/bin/ptxas, /nix/store/j0cx534qlwqdinajjm4f7l3c2460785l-libpng-apng-1.6.43-dev/bin/ptxas, /nix/store/v4kqr0bx2d6acb2bxmj4ck71mxjy9i10-libjpeg-turbo-3.0.4-bin/bin/ptxas, /nix/store/03vb8v0mmr98b2rh11klsyd5xhmk0wd2-libdeflate-1.22/bin/ptxas, 
/nix/store/mmah8dm69x4rj9zhpf0nawrlr6wwnmpp-libwebp-1.4.0/bin/ptxas, /nix/store/h18s640fnhhj2qdh5vivcfbxvz377srg-xz-5.6.3-bin/bin/ptxas, /nix/store/sf8b71w6f9r0fb3ya6ymnn7gpxdrjfzw-libtiff-4.7.0-bin/bin/ptxas, /nix/store/br6d30xcg13w6jjsj5rlp51779skipjc-lcms2-2.16-bin/bin/ptxas, /nix/store/75v06fahb45kf0rcq37957n93jxka2cf-libwebp-1.4.0/bin/ptxas, /nix/store/rjjzv280allyyi10z049fkpsjpmwfj42-fftw-double-3.3.10-dev/bin/ptxas, /nix/store/wj4gp2b1mks14ry6m25rfwkpbj0qwvc0-imagemagick-7.1.1-43/bin/ptxas, /nix/store/6hlrwk6ia8p127a8vzhg9vsmx3fs5kfm-python3.12-chardet-5.2.0/bin/ptxas, /run/wrappers/bin/ptxas, /home/slhuan/.nix-profile/bin/ptxas, /nix/profile/bin/ptxas, /home/slhuan/.local/state/nix/profile/bin/ptxas, /etc/profiles/per-user/slhuan/bin/ptxas, /nix/var/nix/profiles/default/bin/ptxas, /run/current-system/sw/bin/ptxas, add.py.runfiles/cuda_nvcc/bin/ptxas, add.py/cuda_nvcc/bin/ptxas, bin/ptxas, /usr/local/cuda/bin/ptxas, /nix/store/r1dw8c14c4mvw6b44y0hmvxajap7s6aq-python3-3.12.8-env/lib/python3.12/site-packages/jax_cuda12_plugin/../nvidia/cuda_nvcc/bin/ptxas, /nix/store/r1dw8c14c4mvw6b44y0hmvxajap7s6aq-python3-3.12.8-env/lib/python3.12/site-packages/jax_cuda12_plugin/../../nvidia/cuda_nvcc/bin/ptxas, /nix/store/r1dw8c14c4mvw6b44y0hmvxajap7s6aq-python3-3.12.8-env/lib/python3.12/site-packages/jax_cuda12_plugin/cuda/bin/ptxas

As far as I can tell, the paths we can control are

  • jax_cuda12_plugin/../nvidia/cuda_nvcc/bin/ptxas
  • jax_cuda12_plugin/../../nvidia/cuda_nvcc/bin/ptxas
  • jax_cuda12_plugin/cuda/bin/ptxas

The first two paths match the structure of jax-cuda12-pjrt's jax_plugins, but they are frustratingly outside the jax_cuda12_plugin folder, so I was forced to use the last path, hence the awkward copy. I'm not sure why we need to do this for custom calls specifically; things like jax.jit work without it.

Lastly, the reason this PR is a draft is that I haven't gotten operations using cuSolver, like jax.numpy.linalg.inv, jax.numpy.linalg.slogdet, etc., to work. Running the above example gives the error

E0119 18:00:52.316803  330063 pjrt_stream_executor_client.cc:3086] Execution of replica 0 failed: INTERNAL: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
Traceback (most recent call last):
  File "...", line 12, in <module>
    print(jnp.linalg.inv(x))  # libcusolver
          ^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error

Supposedly CUSOLVER_STATUS_INTERNAL_ERROR "is usually caused by a cudaMemcpyAsync() failure"? I tried running with XLA_PYTHON_CLIENT_PREALLOCATE=false to no avail, cf. jax-ml/jax#8916, jax-ml/jax#12846, jax-ml/jax#21950. Maybe the memory is a red herring, since I get the same error even when libcusolver is not provided (i.e., the current state of jax before this PR), so it might be some version incompatibility? The GPU is an NVIDIA GeForce RTX 4070 and I am using NixOS with

{
  hardware.nvidia.open = true;
  services.xserver.videoDrivers = [ "nvidia" ];
}

with driver version 550.142 and CUDA version 12.4. I checked that the CUDA libraries match upstream's minimums.
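For reference, the preallocation knob mentioned above is read by XLA when the backend initializes, so it has to be in the environment before the first JAX import; a minimal sketch:

```python
import os

# XLA reads this variable at backend initialization, so it must be
# set before the first `import jax` in the process (or exported in
# the shell before launching Python).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
```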

(The last two commits are small cleanups.)

cc @natsukium @GaetanLepage @samuela

  • Built on platform(s)
    • x86_64-linux
    • aarch64-linux
    • x86_64-darwin
    • aarch64-darwin
  • For non-Linux: Is sandboxing enabled in nix.conf? (See Nix manual)
    • sandbox = relaxed
    • sandbox = true
  • Tested, as applicable:
  • Tested compilation of all packages that depend on this change using nix-shell -p nixpkgs-review --run "nixpkgs-review rev HEAD". Note: all changes have to be committed, also see nixpkgs-review usage
  • Tested basic functionality of all binary files (usually in ./result/bin/)
  • 25.05 Release Notes (or backporting 24.11 and 25.05 Release notes)
    • (Package updates) Added a release notes entry if the change is major or breaking
    • (Module updates) Added a release notes entry if the change is significant
    • (Module addition) Added a release notes entry if adding a new NixOS module
  • Fits CONTRIBUTING.md.

Add a 👍 reaction to pull requests you find important.

@github-actions github-actions bot added 6.topic: python 10.rebuild-darwin: 0 This PR does not cause any packages to rebuild on Darwin 10.rebuild-linux: 1-10 labels Jan 19, 2025
@nix-owners nix-owners bot requested a review from natsukium January 19, 2025 23:17
@stephen-huan
Member Author

Not so important, but while trying to debug I found commands like

python -c "from jax._src.xla_bridge import _check_cuda_versions; _check_cuda_versions()"

are failing with the following error message since they rely directly on jax-cuda12-plugin (and not jax-cuda12-pjrt).

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 431, in _check_cuda_versions
    raise RuntimeError(f'Unable to use CUDA because of the '
RuntimeError: Unable to use CUDA because of the following issues with CUDA components:
Unable to load CUDA. Is it installed?
Traceback (most recent call last):
  File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 349, in _version_check
    version = get_version()
              ^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:33: operation cudaRuntimeGetVersion(&version) failed: Error loading CUDA libraries. GPU will not be used.

--------------------------------------------------
Unable to load cuDNN. Is it installed?
Traceback (most recent call last):
  File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 349, in _version_check
    version = get_version()
              ^^^^^^^^^^^^^
RuntimeError: cuDNN not found.

--------------------------------------------------
Unable to load cuFFT. Is it installed?
Traceback (most recent call last):
  File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 349, in _version_check
    version = get_version()
              ^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:54: operation cufftGetVersion(&version) failed: cuFFT internal error

--------------------------------------------------
Unable to load cuSOLVER. Is it installed?
Traceback (most recent call last):
  File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 349, in _version_check
    version = get_version()
              ^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:61: operation cusolverGetVersion(&version) failed: cuSolver internal error

--------------------------------------------------
Unable to load cuPTI. Is it installed?
Traceback (most recent call last):
  File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 349, in _version_check
    version = get_version()
              ^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:47: operation cuptiGetVersion(&version) failed: Unknown CUPTI error 999. This probably means that JAX was unable to load cupti.

--------------------------------------------------
Unable to load cuBLAS. Is it installed?
Traceback (most recent call last):
  File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 349, in _version_check
    version = get_version()
              ^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:71: operation cublasGetVersion( nullptr, &version) failed: cuBlas internal error

--------------------------------------------------
Unable to load cuSPARSE. Is it installed?
Traceback (most recent call last):
  File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 349, in _version_check
    version = get_version()
              ^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:80: operation cusparseGetProperty(MAJOR_VERSION, &major) failed: The cuSPARSE library was not found.

I tried using the same patchelf strategy as in jax-cuda12-pjrt, but it's not working; I'm not sure what's going wrong.

preInstallCheck = ''
  patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_plugins/xla_cuda12/xla_cuda_plugin.so
'';

--- a/pkgs/development/python-modules/jax-cuda12-plugin/default.nix
+++ b/pkgs/development/python-modules/jax-cuda12-plugin/default.nix
@@ -3,6 +3,7 @@
   stdenv,
   buildPythonPackage,
   fetchPypi,
+  addDriverRunpath,
   autoPatchelfHook,
   pypaInstallHook,
   wheelUnpackHook,
@@ -15,6 +16,22 @@ let
   inherit (cudaPackages) cudaVersion;
   inherit (jaxlib) version;
 
+  cudaLibPath = lib.makeLibraryPath (
+    with cudaPackages;
+    [
+      (lib.getLib libcublas) # libcublas.so
+      (lib.getLib cuda_cupti) # libcupti.so
+      (lib.getLib cuda_cudart) # libcudart.so
+      (lib.getLib cudnn) # libcudnn.so
+      (lib.getLib libcufft) # libcufft.so
+      (lib.getLib libcusolver) # libcusolver.so
+      (lib.getLib libcusparse) # libcusparse.so
+      (lib.getLib nccl) # libnccl.so
+      (lib.getLib libnvjitlink) # libnvJitLink.so
+      (lib.getLib addDriverRunpath.driverLink) # libcuda.so
+    ]
+  );
+
   getSrcFromPypi =
     {
       platform,
@@ -94,12 +111,20 @@ buildPythonPackage {
     wheelUnpackHook
   ];
 
+  postInstall = ''
+    cp -r ${jax-cuda12-pjrt}/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc -T $out/${python.sitePackages}/jax_cuda12_plugin/cuda
+  '';
+
+  preInstallCheck = ''
+    patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so
+  '';
+
   dependencies = [ jax-cuda12-pjrt ];
 
   pythonImportsCheck = [ "jax_cuda12_plugin" ];
 
-  # no tests
-  doCheck = false;
+  # FIXME: there are no tests, but we need to run preInstallCheck above
+  doCheck = true;
 
   meta = {
     description = "JAX Plugin for CUDA12";

@stephen-huan stephen-huan marked this pull request as ready for review January 20, 2025 05:53
Comment on lines +119 to +108
postInstall = ''
  mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
  ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
  ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
'';
Member Author


This is mostly copied from jax-cuda12-pjrt, with the exception of the nvvm folder. I'm not sure where to put it, because I haven't encountered any errors with it missing.

postInstall = ''
  mkdir -p $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/bin
  ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/bin/ptxas
  ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/bin/nvlink
  ln -s ${cudaPackages.cuda_nvcc}/nvvm $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/nvvm
'';

@stephen-huan
Member Author

The error I was getting came from accidentally leaving a copy of jax_cuda12_plugin in the working directory, which shadowed the Nix version... oops. cuSolver is working, so this PR is ready for review.

@stephen-huan
Member Author

@natsukium @GaetanLepage @samuela could I get a review of this PR? Thanks!

@samuela
Member

samuela commented Feb 2, 2025

Hi @stephen-huan, thanks for putting this together. IIRC we set up a suite of CUDA-enabled tests for jax/jaxlib a while back; I'm curious why those tests didn't flag these issues earlier. In any case, we should fill those gaps in the test suite while we're at it.

@stephen-huan stephen-huan force-pushed the jax-missing-cuda branch 2 times, most recently from 62e7930 to 7988811 Compare February 2, 2025 04:13
@stephen-huan
Member Author

@samuela I've added tests, as suggested. The reason it wasn't caught by the existing test is quite simple: the current test only does matrix multiplication, which exercises libcublas and a few other libraries, but it doesn't use libcufft (FFT), libcusolver (inverse, log-determinant, etc.), or libcusparse (sparse operations), so it doesn't error if those libraries are missing.

The updated tests exercise every library* (verified by deleting the corresponding line in cudaLibPath) except

  • cuda_cupti
  • nccl
  • libnvjitlink

cuPTI is for profiling tools. NCCL provides communication collectives for multi-GPU setups, so it is hard to test. According to upstream,

nvjitlink is not a direct dependency of JAX, but it is a transitive dependency via, for example, cuSOLVER.

but I haven't found an operation that actually requires it.

(*libcusolver pulls in libcusparse so the test doesn't error if libcusparse is missing, but it doesn't hurt to be explicit.)
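A minimal sketch of what such per-library smoke tests can look like (the shapes and structure here are my own illustration, not the exact nixpkgs test):

```python
import jax.numpy as jnp
from jax import random
from jax.experimental import sparse

rng = random.key(0)
k1, k2 = random.split(rng)

n = 4
x = random.normal(k1, (n, n))

# libcufft: FFT of a dense matrix.
assert jnp.fft.fft(x).shape == (n, n)

# libcusolver (which pulls in libcusparse): inverse and log-determinant.
inv = jnp.linalg.inv(x)
sign, logdet = jnp.linalg.slogdet(x)
assert inv.shape == (n, n)

# libcusparse explicitly: a sparse matvec via the BCOO format.
m = sparse.BCOO.fromdense(x)
y = m @ random.normal(k2, (n,))
assert y.shape == (n,)
```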

@stephen-huan
Member Author

I would also like to test XLA custom calls, but as mentioned in the PR body I don't have a clean minimal working example for this. I'm planning on submitting a PR for jax-triton (which I think is an independently useful piece of software) and can test it there with a CUDA test like jax's.

@stephen-huan stephen-huan force-pushed the jax-missing-cuda branch 2 times, most recently from 976bf48 to 6981825 Compare February 2, 2025 04:47
@samuela
Member

samuela commented Feb 2, 2025

> I would also like to test xla custom calls, but as mentioned in the PR body I don't have a clean minimum working example for this. I'm planning on submitting a PR for jax-triton (which I think is an independently useful piece of software) and can test it there with a cuda test like jax's.

Pallas is also an option and built in to JAX if you would like to test this without depending on jax-triton.

@stephen-huan
Member Author

I tried Pallas, but couldn't reproduce the error. I suspect it has nothing to do with XLA custom calls, since by that point compilation has already happened (and should not need a compiler); rather, it seems unique to the Triton support in jaxlib, since that JIT-compiles:

  // TODO(cjfj): Support `TRITON_PTXAS_PATH` environment variable?
  int cc_major = compute_capability / 10;
  int cc_minor = compute_capability % 10;
  JAX_ASSIGN_OR_RETURN(
      std::vector<uint8_t> module_image,
      stream_executor::CompileGpuAsm(cc_major, cc_minor, ptx.data(),
                                     stream_executor::GpuAsmOpts{}));

I'm not sure why Pallas doesn't trigger this code path, as Pallas lowers to Triton on GPU, but considering jax.jit works, it's possible the Pallas backend compiles more ahead-of-time than the jax-triton approach (if that makes any sense). It's also not clear to me what the difference between jax-cuda12-pjrt and jax-cuda12-plugin is, and why jax successfully finds ptxas in jax-cuda12-pjrt for jax.jit but only searches jax-cuda12-plugin for the jax-triton Triton kernel.
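For completeness, the kind of Pallas kernel I tried looks roughly like this (the `add_kernel`/`add` names are my own; `interpret=True` is only there so the sketch runs without a GPU, whereas on GPU it lowers through Triton):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Each ref is a block of the corresponding input/output array.
    o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # interpreter mode: runs on CPU, no ptxas needed
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))  # elementwise sum, i.e. 2 * x
```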
