diff --git a/cuda_bindings/cuda/bindings/_path_finder/find_nvidia_headers.py b/cuda_bindings/cuda/bindings/_path_finder/find_nvidia_headers.py new file mode 100644 index 000000000..befd4ca81 --- /dev/null +++ b/cuda_bindings/cuda/bindings/_path_finder/find_nvidia_headers.py @@ -0,0 +1,42 @@ +# Copyright 2025 NVIDIA Corporation. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +import functools +import glob +import os + +from cuda.bindings._path_finder.find_sub_dirs import find_sub_dirs_all_sitepackages +from cuda.bindings._path_finder.supported_libs import IS_WINDOWS + + +@functools.cache +def find_nvidia_header_directory(libname: str) -> str: + if libname != "nvshmem": + raise RuntimeError(f"UNKNOWN {libname=}") + + if libname == "nvshmem" and IS_WINDOWS: + # nvshmem has no Windows support. + return None + + # Installed from a wheel + nvidia_sub_dirs = ("nvidia", "nvshmem", "include") + for hdr_dir in find_sub_dirs_all_sitepackages(nvidia_sub_dirs): + nvshmem_h_path = os.path.join(hdr_dir, "nvshmem.h") + if os.path.isfile(nvshmem_h_path): + return hdr_dir + + conda_prefix = os.environ.get("CONDA_PREFIX") + if conda_prefix and os.path.isdir(conda_prefix): + hdr_dir = os.path.join(conda_prefix, "include") + if os.path.isdir(hdr_dir): + nvshmem_h_path = os.path.join(hdr_dir, "nvshmem.h") + if os.path.isfile(nvshmem_h_path): + return hdr_dir + + for hdr_dir in sorted(glob.glob("/usr/include/nvshmem_*")): + if os.path.isdir(hdr_dir): + nvshmem_h_path = os.path.join(hdr_dir, "nvshmem.h") + if os.path.isfile(nvshmem_h_path): + return hdr_dir + + return None diff --git a/cuda_bindings/tests/test_path_finder_find_headers.py b/cuda_bindings/tests/test_path_finder_find_headers.py new file mode 100644 index 000000000..d76081789 --- /dev/null +++ b/cuda_bindings/tests/test_path_finder_find_headers.py @@ -0,0 +1,30 @@ +# Copyright 2025 NVIDIA Corporation. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE + +import pytest + +from cuda.bindings._path_finder import find_nvidia_headers + + +def test_find_nvidia_header_directory(info_summary_append): + with pytest.raises(RuntimeError, match="^UNKNOWN libname='unknown-libname'$"): + find_nvidia_headers.find_nvidia_header_directory("unknown-libname") + + hdr_dir = find_nvidia_headers.find_nvidia_header_directory("nvshmem") + # TODO: Find ways to test more meaningfully, and how to avoid HARD-WIRED PATHS in particular. + assert hdr_dir in [ + # pip install nvidia-nvshmem-cu12 + "/home/rgrossekunst/forked/cuda-python/venvs/scratch/lib/python3.12/site-packages/nvidia/nvshmem/include", + # + # conda create -y -n nvshmem python=3.12 + # conda activate nvshmem + # conda install -y conda-forge::libnvshmem3 conda-forge::libnvshmem-dev + "/home/rgrossekunst/miniforge3/envs/nvshmem/include", + # + # sudo apt install libnvshmem3-cuda-12 libnvshmem3-dev-cuda-12 + "/usr/include/nvshmem_12", + # + # nvshmem not available + None, + ] + info_summary_append(f"{hdr_dir=!r}")