Add Python wrappers for c.parallel scan API #3592

Open · wants to merge 8 commits into base: main
15 changes: 0 additions & 15 deletions python/cuda_parallel/cuda/parallel/experimental/_bindings.py
@@ -11,8 +11,6 @@

from cuda.cccl import get_include_paths # type: ignore[import-not-found]

from . import _cccl as cccl


@lru_cache()
def get_bindings() -> ctypes.CDLL:
@@ -32,19 +30,6 @@ def get_bindings() -> ctypes.CDLL:
else:
raise RuntimeError(f"Unable to locate {so_path}")
_bindings = ctypes.CDLL(str(cccl_c_path))
_bindings.cccl_device_reduce.restype = ctypes.c_int
_bindings.cccl_device_reduce.argtypes = [
cccl.DeviceReduceBuildResult,
ctypes.c_void_p,
ctypes.POINTER(ctypes.c_ulonglong),
cccl.Iterator,
cccl.Iterator,
ctypes.c_ulonglong,
cccl.Op,
cccl.Value,
ctypes.c_void_p,
]
_bindings.cccl_device_reduce_cleanup.restype = ctypes.c_int
return _bindings


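With the per-function argtypes declarations gone, every pointer-like argument must be wrapped explicitly at the call site — which is exactly what the reduce and scan call sites below now do with ctypes.c_void_p(...). A minimal sketch of the two styles, using libc's strlen as a stand-in for a cccl binding (POSIX only; illustrative, not part of the PR):

import ctypes
import ctypes.util

libc = ctypes.CDLL(ctypes.util.find_library("c"))

# Style removed by this diff: declare the signature once on the library object.
libc.strlen.restype = ctypes.c_size_t
libc.strlen.argtypes = [ctypes.c_char_p]
assert libc.strlen(b"cccl") == 4

# Style the diff moves to: no declarations, so each argument is wrapped per
# call (a bare Python int would otherwise be passed as a 32-bit C int,
# truncating 64-bit pointers).
libc_raw = ctypes.CDLL(ctypes.util.find_library("c"))  # fresh handle, nothing declared
assert libc_raw.strlen(ctypes.c_char_p(b"cccl")) == 4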
8 changes: 5 additions & 3 deletions python/cuda_parallel/cuda/parallel/experimental/_caching.py
@@ -49,10 +49,12 @@ class CachableFunction:

def __init__(self, func):
self._func = func

closure = func.__closure__ if func.__closure__ is not None else []
self._identity = (
self._func.__code__.co_code,
self._func.__code__.co_consts,
self._func.__closure__,
func.__code__.co_code,
func.__code__.co_consts,
tuple(cell.cell_contents for cell in closure),
)

def __eq__(self, other):
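Context for the change above: closure cell objects compare by identity, so two structurally identical functions would never produce equal cache keys; comparing cell contents fixes that (and the is-not-None guard handles functions with no closure at all). A self-contained illustration, separate from the PR:

def make(op):
    def apply(x):
        return op(x)
    return apply

f, g = make(abs), make(abs)
assert f.__closure__[0] is not g.__closure__[0]  # distinct cell objects
assert f.__closure__[0].cell_contents is g.__closure__[0].cell_contents  # equal contents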
31 changes: 31 additions & 0 deletions python/cuda_parallel/cuda/parallel/experimental/_cccl.py
@@ -6,6 +6,7 @@

import ctypes
import functools
from typing import Callable

import numba
import numpy as np
@@ -92,6 +93,21 @@ class DeviceReduceBuildResult(ctypes.Structure):
]


# MUST match `cccl_device_scan_build_result_t` in c/include/cccl/c/scan.h
class DeviceScanBuildResult(ctypes.Structure):
_fields_ = [
("cc", ctypes.c_int),
("cubin", ctypes.c_void_p),
("cubin_size", ctypes.c_size_t),
("library", ctypes.c_void_p),
("accumulator_type", TypeInfo),
("init_kernel", ctypes.c_void_p),
("scan_kernel", ctypes.c_void_p),
("description_bytes_per_tile", ctypes.c_size_t),
("payload_bytes_per_tile", ctypes.c_size_t),
]


# MUST match `cccl_value_t` in c/include/cccl/c/types.h
class Value(ctypes.Structure):
_fields_ = [("type", TypeInfo), ("state", ctypes.c_void_p)]
@@ -223,3 +239,18 @@ def to_cccl_value(array_or_struct: np.ndarray | GpuStruct) -> Value:
else:
# it's a GpuStruct, use the array underlying it
return to_cccl_value(array_or_struct._data)


def to_cccl_op(op: Callable, sig) -> Op:
ltoir, _ = cuda.compile(op, sig=sig, output="ltoir")
name = op.__name__.encode("utf-8")
return Op(
OpKind.STATELESS,
name,
ctypes.c_char_p(ltoir),
len(ltoir),
1,
1,
None,
_data=(ltoir, name), # keep a reference to these in a _data attribute
)
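A hedged usage sketch for the new helper (the op and Numba signature here are illustrative; compilation requires a CUDA toolchain):

import numba

def add(a, b):
    return a + b

# Compile add to LTO-IR once and wrap it as a stateless device operator:
op = to_cccl_op(add, (numba.int32, numba.int32))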
1 change: 1 addition & 0 deletions python/cuda_parallel/cuda/parallel/experimental/algorithms/__init__.py
@@ -4,3 +4,4 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .reduce import reduce_into as reduce_into
from .scan import scan as scan
python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py
@@ -6,7 +6,6 @@
from __future__ import annotations # TODO: required for Python 3.7 docs env

import ctypes
from functools import cached_property
from typing import Callable

import numba
@@ -22,28 +21,6 @@
from ..typing import DeviceArrayLike, GpuStruct


class _Op:
def __init__(self, h_init: np.ndarray | GpuStruct, op: Callable):
if isinstance(h_init, np.ndarray):
value_type = numba.from_dtype(h_init.dtype)
else:
value_type = numba.typeof(h_init)
self.ltoir, _ = cuda.compile(op, sig=(value_type, value_type), output="ltoir")
self.name = op.__name__.encode("utf-8")

@cached_property
def handle(self) -> cccl.Op:
return cccl.Op(
cccl.OpKind.STATELESS,
self.name,
ctypes.c_char_p(self.ltoir),
len(self.ltoir),
1,
1,
None,
)


def _dtype_validation(dt1, dt2):
if dt1 != dt2:
raise TypeError(f"dtype mismatch: __init__={dt1}, __call__={dt2}")
@@ -66,15 +43,19 @@ def __init__(
self.h_init_cccl = cccl.to_cccl_value(h_init)
cc_major, cc_minor = cuda.get_current_device().compute_capability
cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths()
bindings = get_bindings()
self.op_wrapper = _Op(h_init, op)
if isinstance(h_init, np.ndarray):
value_type = numba.from_dtype(h_init.dtype)
else:
value_type = numba.typeof(h_init)
sig = (value_type, value_type)
self.op_wrapper = cccl.to_cccl_op(op, sig)
self.build_result = cccl.DeviceReduceBuildResult()
self.bindings = get_bindings()
error = bindings.cccl_device_reduce_build(
error = self.bindings.cccl_device_reduce_build(
ctypes.byref(self.build_result),
self.d_in_cccl,
self.d_out_cccl,
self.op_wrapper.handle,
self.op_wrapper,
cccl.to_cccl_value(h_init),
cc_major,
cc_minor,
@@ -118,14 +99,14 @@ def __call__(

error = self.bindings.cccl_device_reduce(
self.build_result,
d_temp_storage,
ctypes.c_void_p(d_temp_storage),
ctypes.byref(temp_storage_bytes),
self.d_in_cccl,
self.d_out_cccl,
ctypes.c_ulonglong(num_items),
self.op_wrapper.handle,
self.op_wrapper,
self.h_init_cccl,
stream_handle,
ctypes.c_void_p(stream_handle),
)

if error != enums.CUDA_SUCCESS:
@@ -164,10 +145,10 @@ def reduce_into(
op: Callable,
h_init: np.ndarray,
):
"""Computes a device-wide reduction using the specified binary ``op`` functor and initial value ``init``.
"""Computes a device-wide reduction using the specified binary ``op`` and initial value ``init``.
Example:
The code snippet below demonstrates the usage of the ``reduce_into`` API:
Below, ``reduce_into`` is used to compute the minimum value of a sequence of integers.
.. literalinclude:: ../../python/cuda_parallel/tests/test_reduce_api.py
:language: python
Expand All @@ -176,9 +157,9 @@ def reduce_into(
:end-before: example-end reduce-min
Args:
d_in: CUDA device array storing the input sequence of data items
d_out: CUDA device array storing the output aggregate
op: Binary reduction
d_in: Device array or iterator containing the input sequence of data items
d_out: Device array (of size 1) that will store the result of the reduction
op: Callable representing the binary operator to apply
init: Numpy array storing initial value of the reduction
Returns:
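To make the two-phase call pattern above concrete, a usage sketch in the spirit of the docstring's reduce-min example (CuPy and the array contents are illustrative; a CUDA device is required):

import cupy as cp
import numpy as np
from cuda.parallel.experimental.algorithms import reduce_into

def min_op(a, b):
    return a if a < b else b

d_in = cp.asarray([4, 1, 3], dtype=np.int32)
d_out = cp.empty(1, dtype=np.int32)
h_init = np.array([np.iinfo(np.int32).max], dtype=np.int32)

reducer = reduce_into(d_in, d_out, min_op, h_init)
nbytes = reducer(None, d_in, d_out, len(d_in), h_init)  # first call: size query only
scratch = cp.empty(nbytes, dtype=np.uint8)
reducer(scratch, d_in, d_out, len(d_in), h_init)        # second call: the reduction
assert int(d_out[0]) == 1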
165 changes: 165 additions & 0 deletions python/cuda_parallel/cuda/parallel/experimental/algorithms/scan.py
@@ -0,0 +1,165 @@
# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations # TODO: required for Python 3.7 docs env

import ctypes
from typing import Callable

import numba
import numpy as np
from numba import cuda
from numba.cuda.cudadrv import enums

from .. import _cccl as cccl
from .._bindings import get_bindings, get_paths
from .._caching import CachableFunction, cache_with_key
from .._utils import protocols
from ..iterators._iterators import IteratorBase
from ..typing import DeviceArrayLike, GpuStruct


class _Scan:
# TODO: constructor shouldn't require concrete `d_in`, `d_out`:
def __init__(
self,
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike | IteratorBase,
op: Callable,
h_init: np.ndarray | GpuStruct,
):
# Referenced from __del__:
self.build_result = None

self.d_in_cccl = cccl.to_cccl_iter(d_in)
self.d_out_cccl = cccl.to_cccl_iter(d_out)
self.h_init_cccl = cccl.to_cccl_value(h_init)
cc_major, cc_minor = cuda.get_current_device().compute_capability
cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths()
if isinstance(h_init, np.ndarray):
value_type = numba.from_dtype(h_init.dtype)
else:
value_type = numba.typeof(h_init)
sig = (value_type, value_type)
self.op_wrapper = cccl.to_cccl_op(op, sig)
self.build_result = cccl.DeviceScanBuildResult()
self.bindings = get_bindings()
error = self.bindings.cccl_device_scan_build(
ctypes.byref(self.build_result),
self.d_in_cccl,
self.d_out_cccl,
self.op_wrapper,
cccl.to_cccl_value(h_init),
cc_major,
cc_minor,
ctypes.c_char_p(cub_path),
ctypes.c_char_p(thrust_path),
ctypes.c_char_p(libcudacxx_path),
ctypes.c_char_p(cuda_include_path),
)
if error != enums.CUDA_SUCCESS:
raise ValueError("Error building scan")

def __call__(
self,
temp_storage,
d_in,
d_out,
num_items: int,
h_init: np.ndarray | GpuStruct,
stream=None,
):
if self.d_in_cccl.type.value == cccl.IteratorKind.POINTER:
self.d_in_cccl.state = protocols.get_data_pointer(d_in)
else:
self.d_in_cccl.state = d_in.state

if self.d_out_cccl.type.value == cccl.IteratorKind.POINTER:
self.d_out_cccl.state = protocols.get_data_pointer(d_out)
else:
self.d_out_cccl.state = d_out.state

self.h_init_cccl.state = h_init.__array_interface__["data"][0]

stream_handle = protocols.validate_and_get_stream(stream)

if temp_storage is None:
temp_storage_bytes = ctypes.c_size_t()
d_temp_storage = None
else:
temp_storage_bytes = ctypes.c_size_t(temp_storage.nbytes)
d_temp_storage = protocols.get_data_pointer(temp_storage)

error = self.bindings.cccl_device_scan(
self.build_result,
ctypes.c_void_p(d_temp_storage),
ctypes.byref(temp_storage_bytes),
self.d_in_cccl,
self.d_out_cccl,
ctypes.c_ulonglong(num_items),
self.op_wrapper,
self.h_init_cccl,
ctypes.c_void_p(stream_handle),
)

if error != enums.CUDA_SUCCESS:
raise ValueError("Error reducing")

return temp_storage_bytes.value

def __del__(self):
if self.build_result is None:
return
bindings = get_bindings()
bindings.cccl_device_scan_cleanup(ctypes.byref(self.build_result))


def make_cache_key(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike | IteratorBase,
op: Callable,
h_init: np.ndarray,
):
d_in_key = (
d_in.kind if isinstance(d_in, IteratorBase) else protocols.get_dtype(d_in)
)
d_out_key = (
d_out.kind if isinstance(d_out, IteratorBase) else protocols.get_dtype(d_out)
)
op_key = CachableFunction(op)
h_init_key = h_init.dtype
return (d_in_key, d_out_key, op_key, h_init_key)
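Worth noting: the key above captures iterator kind or dtype for both operands, the op (via CachableFunction), and the init dtype — but not array sizes, so one compiled build is reused across inputs of any length. Hypothetically, both of these calls map to the same cache entry:

scan(cp.zeros(10, np.int32), d_out_a, my_op, h_init)
scan(cp.zeros(99, np.int32), d_out_b, my_op, h_init)  # same dtypes, same op: cache hit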


# TODO Figure out `sum` without operator and initial value
# TODO Accept stream
@cache_with_key(make_cache_key)
def scan(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike | IteratorBase,
op: Callable,
h_init: np.ndarray,
):
"""Computes a device-wide scan using the specified binary ``op`` and initial value ``init``.
Example:
Below, ``scan`` is used to compute an exclusive scan of a sequence of integers.
.. literalinclude:: ../../python/cuda_parallel/tests/test_scan_api.py
:language: python
:dedent:
:start-after: example-begin scan-max
:end-before: example-end scan-max
Args:
d_in: Device array or iterator containing the input sequence of data items
d_out: Device array that will store the result of the scan
op: Callable representing the binary operator to apply
init: Numpy array storing initial value of the scan
Returns:
A callable object that can be used to perform the scan
"""
return _Scan(d_in, d_out, op, h_init)
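Finally, a usage sketch mirroring the reduce example above (illustrative values; assumes the exclusive-scan semantics described in the docstring and a CUDA device):

import cupy as cp
import numpy as np
from cuda.parallel.experimental.algorithms import scan

def add_op(a, b):
    return a + b

d_in = cp.asarray([1, 2, 3, 4], dtype=np.int32)
d_out = cp.empty_like(d_in)
h_init = np.array([0], dtype=np.int32)

scanner = scan(d_in, d_out, add_op, h_init)
nbytes = scanner(None, d_in, d_out, len(d_in), h_init)  # size query
scratch = cp.empty(nbytes, dtype=np.uint8)
scanner(scratch, d_in, d_out, len(d_in), h_init)        # the scan itself
# With init 0 and exclusive semantics, d_out would hold [0, 1, 3, 6].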