Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce cuda.cooperative overloads not requiring temporary storage #2528

Merged
40 changes: 22 additions & 18 deletions python/cuda_cooperative/cuda/cooperative/experimental/_nvrtc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from cuda import nvrtc
from cuda.cooperative.experimental._caching import disk_cache
from cuda.cooperative.experimental._common import check_in, version
import importlib.resources as pkg_resources
import importlib
import functools


def CHECK_NVRTC(err, prog):
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
err, logsize = nvrtc.nvrtcGetProgramLogSize(prog)
Expand Down Expand Up @@ -39,23 +40,24 @@ def get_cuda_path():
# rdc is true or false
# code is lto or ptx
# @cache
@functools.lru_cache(maxsize=32) # Always enabled
@disk_cache # Optional, see caching.py
@functools.lru_cache(maxsize=32) # Always enabled
@disk_cache # Optional, see caching.py
def compile_impl(cpp, cc, rdc, code, nvrtc_path, nvrtc_version):
check_in('rdc', rdc, [True, False])
check_in('code', code, ['lto', 'ptx'])

with pkg_resources.path('cuda', '_include') as include_path:
cub_path = include_path
thrust_path = include_path
libcudacxx_path = os.path.join(include_path, 'libcudacxx')
cuda_include_path = os.path.join(get_cuda_path(), 'include')

opts = [b"--std=c++17", \
bytes(f"--include-path={cub_path}", encoding='ascii'), \
bytes(f"--include-path={thrust_path}", encoding='ascii'), \
bytes(f"--include-path={libcudacxx_path}", encoding='ascii'), \
bytes(f"--include-path={cuda_include_path}", encoding='ascii'), \
include_path = importlib.resources.files('cuda').joinpath('_include')
include_path_str = str(include_path)
cub_path = include_path_str
thrust_path = include_path_str
libcudacxx_path = str(os.path.join(include_path, 'libcudacxx'))
cuda_include_path = os.path.join(get_cuda_path(), 'include')

opts = [b"--std=c++17",
bytes(f"--include-path={cub_path}", encoding='ascii'),
bytes(f"--include-path={thrust_path}", encoding='ascii'),
bytes(f"--include-path={libcudacxx_path}", encoding='ascii'),
bytes(f"--include-path={cuda_include_path}", encoding='ascii'),
bytes(f"--gpu-architecture=compute_{cc}", encoding='ascii')]
if rdc:
opts += [b"--relocatable-device-code=true"]
Expand All @@ -67,7 +69,8 @@ def compile_impl(cpp, cc, rdc, code, nvrtc_path, nvrtc_version):
opts += [b"-DCCCL_DISABLE_BF16_SUPPORT"]

# Create program
err, prog = nvrtc.nvrtcCreateProgram(str.encode(cpp), b"code.cu", 0, [], [])
err, prog = nvrtc.nvrtcCreateProgram(
str.encode(cpp), b"code.cu", 0, [], [])
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise RuntimeError(f"nvrtcCreateProgram error: {err}")

Expand Down Expand Up @@ -100,12 +103,13 @@ def compile_impl(cpp, cc, rdc, code, nvrtc_path, nvrtc_version):

return ptx.decode('ascii')


def compile(**kwargs):
    """Query the NVRTC version, then compile through ``compile_impl``.

    Args:
        **kwargs: Forwarded unchanged to ``compile_impl``. Its visible
            signature accepts ``cpp``, ``cc``, ``rdc`` and ``code``.

    Returns:
        A ``(nvrtc_version, result)`` tuple, where ``nvrtc_version`` is
        ``version(major, minor)`` built from ``nvrtc.nvrtcVersion()`` and
        ``result`` is whatever ``compile_impl`` returns.

    Raises:
        RuntimeError: If ``nvrtc.nvrtcVersion()`` does not report
            ``NVRTC_SUCCESS``.
    """
    err, major, minor = nvrtc.nvrtcVersion()
    if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
        raise RuntimeError(f"nvrtcVersion error: {err}")
    nvrtc_version = version(major, minor)

    # The NVRTC module path and version are passed as extra arguments to the
    # cached compile_impl -- presumably so cached results are tied to the
    # NVRTC build that produced them (confirm against the caching module).
    return nvrtc_version, compile_impl(
        **kwargs,
        nvrtc_path=nvrtc.__file__,
        nvrtc_version=nvrtc_version,
    )
Loading
Loading