python3Packages.jax: add missing cuda libraries #375186
Not so important, but while trying to debug I found that commands like python -c "from jax._src.xla_bridge import _check_cuda_versions; _check_cuda_versions()" fail with the following error message, since they rely directly on jax-cuda12-plugin (and not jax-cuda12-pjrt).
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 431, in _check_cuda_versions
raise RuntimeError(f'Unable to use CUDA because of the '
RuntimeError: Unable to use CUDA because of the following issues with CUDA components:
Unable to load CUDA. Is it installed?
Traceback (most recent call last):
File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 349, in _version_check
version = get_version()
^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:33: operation cudaRuntimeGetVersion(&version) failed: Error loading CUDA libraries. GPU will not be used.
--------------------------------------------------
Unable to load cuDNN. Is it installed?
Traceback (most recent call last):
File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 349, in _version_check
version = get_version()
^^^^^^^^^^^^^
RuntimeError: cuDNN not found.
--------------------------------------------------
Unable to load cuFFT. Is it installed?
Traceback (most recent call last):
File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 349, in _version_check
version = get_version()
^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:54: operation cufftGetVersion(&version) failed: cuFFT internal error
--------------------------------------------------
Unable to load cuSOLVER. Is it installed?
Traceback (most recent call last):
File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 349, in _version_check
version = get_version()
^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:61: operation cusolverGetVersion(&version) failed: cuSolver internal error
--------------------------------------------------
Unable to load cuPTI. Is it installed?
Traceback (most recent call last):
File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 349, in _version_check
version = get_version()
^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:47: operation cuptiGetVersion(&version) failed: Unknown CUPTI error 999. This probably means that JAX was unable to load cupti.
--------------------------------------------------
Unable to load cuBLAS. Is it installed?
Traceback (most recent call last):
File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 349, in _version_check
version = get_version()
^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:71: operation cublasGetVersion( nullptr, &version) failed: cuBlas internal error
--------------------------------------------------
Unable to load cuSPARSE. Is it installed?
Traceback (most recent call last):
File "/nix/store/hp5n2nk09rv4c6gj3pvqba387zclcvjq-python3-3.12.8-env/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 349, in _version_check
version = get_version()
^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:80: operation cusparseGetProperty(MAJOR_VERSION, &major) failed: The cuSPARSE library was not found.
I tried using the same approach as nixpkgs/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix (lines 77 to 79 in 5df4362):
--- a/pkgs/development/python-modules/jax-cuda12-plugin/default.nix
+++ b/pkgs/development/python-modules/jax-cuda12-plugin/default.nix
@@ -3,6 +3,7 @@
stdenv,
buildPythonPackage,
fetchPypi,
+ addDriverRunpath,
autoPatchelfHook,
pypaInstallHook,
wheelUnpackHook,
@@ -15,6 +16,22 @@ let
inherit (cudaPackages) cudaVersion;
inherit (jaxlib) version;
+ cudaLibPath = lib.makeLibraryPath (
+ with cudaPackages;
+ [
+ (lib.getLib libcublas) # libcublas.so
+ (lib.getLib cuda_cupti) # libcupti.so
+ (lib.getLib cuda_cudart) # libcudart.so
+ (lib.getLib cudnn) # libcudnn.so
+ (lib.getLib libcufft) # libcufft.so
+ (lib.getLib libcusolver) # libcusolver.so
+ (lib.getLib libcusparse) # libcusparse.so
+ (lib.getLib nccl) # libnccl.so
+ (lib.getLib libnvjitlink) # libnvJitLink.so
+ (lib.getLib addDriverRunpath.driverLink) # libcuda.so
+ ]
+ );
+
getSrcFromPypi =
{
platform,
@@ -94,12 +111,20 @@ buildPythonPackage {
wheelUnpackHook
];
+ postInstall = ''
+ cp -r ${jax-cuda12-pjrt}/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc -T $out/${python.sitePackages}/jax_cuda12_plugin/cuda
+ '';
+
+ preInstallCheck = ''
+ patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so
+ '';
+
dependencies = [ jax-cuda12-pjrt ];
pythonImportsCheck = [ "jax_cuda12_plugin" ];
- # no tests
- doCheck = false;
+ # FIXME: there are no tests, but we need to run preInstallCheck above
+ doCheck = true;
meta = {
description = "JAX Plugin for CUDA12";
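The diff above adds a colon-separated library path (built with lib.makeLibraryPath) to the plugin's rpath. As a rough way to reason about which libraries such a path would satisfy, here is a minimal Python sketch that mimics the lookup. It is a simplified model and not how the dynamic loader is implemented: it ignores LD_LIBRARY_PATH, $ORIGIN, and the loader cache. The function names and the sonames used in the usage note are assumptions for illustration only.

```python
import os
from typing import Iterable, List, Optional


def resolve_in_rpath(rpath: str, soname: str) -> Optional[str]:
    """Return the first existing <dir>/<soname> among the colon-separated rpath entries."""
    for directory in rpath.split(":"):
        if not directory:
            continue  # skip empty entries such as a trailing colon
        candidate = os.path.join(directory, soname)
        if os.path.exists(candidate):
            return candidate
    return None


def missing_libraries(rpath: str, sonames: Iterable[str]) -> List[str]:
    """List the sonames that none of the rpath directories provide."""
    return [soname for soname in sonames if resolve_in_rpath(rpath, soname) is None]
```

For example, missing_libraries(rpath, ["libcufft.so.11", "libcusolver.so.11"]) would return the names that no directory in rpath contains; the exact sonames depend on the CUDA release and are assumed here.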
postInstall = ''
  mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
  ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
  ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
'';
This is mostly copied from jax-cuda12-pjrt, with the exception of the nvvm folder. I'm not sure where to put it, because I haven't encountered any errors with it missing.
nixpkgs/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix
Lines 64 to 69 in 5df4362
postInstall = ''
  mkdir -p $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/bin
  ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/bin/ptxas
  ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/bin/nvlink
  ln -s ${cudaPackages.cuda_nvcc}/nvvm $out/${python.sitePackages}/jax_plugins/nvidia/cuda_nvcc/nvvm
'';
The error I was getting came from accidentally leaving a copy of jax_cuda12_plugin in the working directory, which overrode the nix version... oops. cusolver is working now, so this PR is ready for review.
@natsukium @GaetanLepage @samuela could I get a review of this PR? Thanks!
Hi @stephen-huan , thanks for putting this together. IIRC we set up a test suite of CUDA-enabled tests for jax/jaxlib a while back. I'm curious why those tests didn't flag these things earlier? In any case, we should fill those gaps in the test suite while we're at it.
@samuela I've added tests, as suggested. The reason it wasn't caught by the existing test is quite simple: the current test only does matrix multiplication, which exercises cublas and a few other libraries, but doesn't use libcufft (fft), libcusolver (inverse, logdet, etc.), or libcusparse (sparse operations), so it doesn't error if those libraries are missing. The updated tests exercise every library* (tested by deleting the corresponding line in
cupti provides profiling tools. nccl provides communication collectives for multi-gpu setups, so it is hard to test. According to upstream,
but I haven't found an operation that actually requires it. (*libcusolver pulls in libcusparse, so the test doesn't error if libcusparse is missing, but it doesn't hurt to be explicit.)
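The library coverage described in this comment can be sketched as a small harness that runs one operation per library and reports every failure instead of stopping at the first. This is a minimal sketch, not the test added in the PR: the helper names are hypothetical, and the mapping from each JAX operation to a CUDA library is an assumption. In particular, whether the sparse product reaches cuSPARSE depends on the JAX version and configuration.

```python
from typing import Callable, Dict, Optional


def run_checks(checks: Dict[str, Callable[[], object]]) -> Dict[str, Optional[str]]:
    """Run each check and map its name to None on success or to the error text on failure."""
    results: Dict[str, Optional[str]] = {}
    for name, check in checks.items():
        try:
            check()
            results[name] = None
        except Exception as exc:  # collect every failure rather than aborting early
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


def jax_gpu_checks() -> Dict[str, Callable[[], object]]:
    """Build one check per library; the op-to-library mapping is an assumption."""
    import jax  # imported lazily so the harness itself does not require JAX
    import jax.numpy as jnp
    from jax.experimental import sparse

    a = jnp.eye(4) * 2.0
    x = jnp.arange(4.0)
    return {
        "cublas (matmul)": lambda: jax.block_until_ready(a @ a),
        "cufft (fft)": lambda: jax.block_until_ready(jnp.fft.fft(x)),
        "cusolver (inv)": lambda: jax.block_until_ready(jnp.linalg.inv(a)),
        "cusolver (slogdet)": lambda: jax.block_until_ready(jnp.linalg.slogdet(a)),
        "cusparse (BCOO matvec)": lambda: jax.block_until_ready(sparse.BCOO.fromdense(a) @ x),
    }


if __name__ == "__main__":
    for name, error in run_checks(jax_gpu_checks()).items():
        print(f"{name}: {'ok' if error is None else error}")
```

Keeping the harness separate from the JAX-specific checks means the reporting logic can be exercised on a machine without a GPU.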
I would also like to test xla custom calls, but as mentioned in the PR body I don't have a clean minimum working example for this. I'm planning on submitting a PR for jax-triton (which I think is an independently useful piece of software) and can test it there with a cuda test like jax's.
Pallas is also an option and built in to JAX if you would like to test this without depending on jax-triton. |
Tried pallas, but couldn't reproduce the error. I suspect it has nothing to do with xla custom calls, since the compilation has already happened (and should not need a compiler), but rather is unique to the triton support in jaxlib, since it jit compiles.
// TODO(cjfj): Support `TRITON_PTXAS_PATH` environment variable?
int cc_major = compute_capability / 10;
int cc_minor = compute_capability % 10;
JAX_ASSIGN_OR_RETURN(
    std::vector<uint8_t> module_image,
    stream_executor::CompileGpuAsm(cc_major, cc_minor, ptx.data(),
                                   stream_executor::GpuAsmOpts{}));
Not sure why pallas doesn't trigger this codepath as pallas lowers to triton on gpu but considering
Things done
It seems #369920 accidentally deleted a number of cuda libraries (this breaks jax.numpy.fft.fft, for example), so I added the ones from upstream. Minimum working example is below.
In addition, I think I'm running into #164176 (comment), as jax wants ptxas for an xla custom call. Apologies for the lack of a minimum working example, as the ffi interface is relatively involved. I'm testing on the example in jax-triton using this derivation for jax-triton (which doesn't involve any c++ because it can leverage code from jaxlib). The full error is
As far as I can tell, the paths we can control are
jax_cuda12_plugin/../nvidia/cuda_nvcc/bin/ptxas
jax_cuda12_plugin/../../nvidia/cuda_nvcc/bin/ptxas
jax_cuda12_plugin/cuda/bin/ptxas
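To make the three locations listed above concrete, the following minimal sketch resolves them relative to an installed jax_cuda12_plugin directory and reports the first one that exists. The function names are hypothetical, and the order in which jaxlib or XLA actually searches is an assumption; only the three relative paths come from the description above.

```python
import os
from typing import Iterable, List, Optional


def candidate_ptxas_paths(plugin_dir: str) -> List[str]:
    """Resolve the three ptxas locations relative to the jax_cuda12_plugin directory."""
    relative = [
        os.path.join("..", "nvidia", "cuda_nvcc", "bin", "ptxas"),
        os.path.join("..", "..", "nvidia", "cuda_nvcc", "bin", "ptxas"),
        os.path.join("cuda", "bin", "ptxas"),
    ]
    # normpath collapses the ".." components textually (it does not follow symlinks)
    return [os.path.normpath(os.path.join(plugin_dir, rel)) for rel in relative]


def first_existing(paths: Iterable[str]) -> Optional[str]:
    """Return the first path that exists as a file, or None if none does."""
    for path in paths:
        if os.path.isfile(path):
            return path
    return None
```

Only the third candidate stays inside the plugin directory, which is why it is the one a derivation for jax_cuda12_plugin can populate in its own output.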
The first two paths match the structure of jax-cuda12-pjrt's jax_plugins but are frustratingly outside of the jax_cuda12_plugin folder, so I was forced to use the last path, hence the awkward copy. Not sure why we need to do this for custom calls specifically; things like jax.jit work without it.
Lastly, the reason this PR is a draft is that I haven't gotten operations using cusolver, like jax.numpy.linalg.inv, jax.numpy.linalg.slogdet, etc., to work. Running the above example gives the error
Supposedly CUSOLVER_STATUS_INTERNAL_ERROR "is usually caused by a cudaMemcpyAsync() failure"? I tried running with XLA_PYTHON_CLIENT_PREALLOCATE=false to no avail, cf. jax-ml/jax#8916, jax-ml/jax#12846, jax-ml/jax#21950. Maybe the memory is a red herring, since I get the same error even when libcusolver is not provided (i.e. the current state of jax before this PR), so it might be some version incompatibility? The GPU is an NVIDIA GeForce RTX 4070 and I am using NixOS with driver version 550.142 and cuda version 12.4. I checked that the cuda libraries match upstream's minimums.
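For anyone repeating the XLA_PYTHON_CLIENT_PREALLOCATE=false experiment from a script rather than the shell, a minimal sketch follows. It assumes the variable has to be in the environment before JAX initializes its GPU backend, so the safest place to set it is before importing jax. The helper name is hypothetical.

```python
import os
from typing import MutableMapping


def disable_preallocation(environ: MutableMapping[str, str] = os.environ) -> None:
    """Ask XLA not to preallocate GPU memory; call this before importing jax."""
    environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


if __name__ == "__main__":
    disable_preallocation()
    import jax  # imported only after the environment is configured

    print(jax.devices())
```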
(the last two commits are some small cleanup commits.)
cc @natsukium @GaetanLepage @samuela