diff --git a/cuda_core/cuda/core/experimental/_device.py b/cuda_core/cuda/core/experimental/_device.py index ab31b014e..7451f1ddc 100644 --- a/cuda_core/cuda/core/experimental/_device.py +++ b/cuda_core/cuda/core/experimental/_device.py @@ -13,6 +13,7 @@ from cuda.core.experimental._utils.cuda_utils import ( ComputeCapability, CUDAError, + _check_driver_error, driver, handle_return, precondition, @@ -930,6 +931,10 @@ def multicast_supported(self) -> bool: return bool(self._get_cached_attribute(driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED)) +_SUCCESS = driver.CUresult.CUDA_SUCCESS +_INVALID_CTX = driver.CUresult.CUDA_ERROR_INVALID_CONTEXT + + class Device: """Represent a GPU and act as an entry point for cuda.core features. @@ -959,7 +964,7 @@ class Device: __slots__ = ("_id", "_mr", "_has_inited", "_properties") - def __new__(cls, device_id=None): + def __new__(cls, device_id: Optional[int] = None): global _is_cuInit if _is_cuInit is False: with _lock: @@ -968,18 +973,24 @@ def __new__(cls, device_id=None): # important: creating a Device instance does not initialize the GPU! 
if device_id is None: - device_id = handle_return(runtime.cudaGetDevice()) - assert_type(device_id, int) - else: - total = handle_return(runtime.cudaGetDeviceCount()) - assert_type(device_id, int) - if not (0 <= device_id < total): - raise ValueError(f"device_id must be within [0, {total}), got {device_id}") + err, dev = driver.cuCtxGetDevice() + if err == _SUCCESS: + device_id = int(dev) + elif err == _INVALID_CTX: + ctx = handle_return(driver.cuCtxGetCurrent()) + assert int(ctx) == 0 + device_id = 0 # cudart behavior + else: + _check_driver_error(err) + elif device_id < 0: + raise ValueError(f"device_id must be >= 0, got {device_id}") # ensure Device is singleton - if not hasattr(_tls, "devices"): - total = handle_return(runtime.cudaGetDeviceCount()) - _tls.devices = [] + try: + devices = _tls.devices + except AttributeError: + total = handle_return(driver.cuDeviceGetCount()) + devices = _tls.devices = [] for dev_id in range(total): dev = super().__new__(cls) dev._id = dev_id @@ -987,7 +998,9 @@ def __new__(cls, device_id=None): # use the SynchronousMemoryResource which does not use memory pools. 
if ( handle_return( - runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0) + driver.cuDeviceGetAttribute( + driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev_id + ) ) ) == 1: dev._mr = _DefaultAsyncMempool(dev_id) @@ -996,9 +1009,12 @@ def __new__(cls, device_id=None): dev._has_inited = False dev._properties = None - _tls.devices.append(dev) + devices.append(dev) - return _tls.devices[device_id] + try: + return devices[device_id] + except IndexError: + raise ValueError(f"device_id must be within [0, {len(devices)}), got {device_id}") from None def _check_context_initialized(self, *args, **kwargs): if not self._has_inited: diff --git a/cuda_core/docs/source/release/0.3.0-notes.rst b/cuda_core/docs/source/release/0.3.0-notes.rst index 0f8cc77ae..4580cc1ec 100644 --- a/cuda_core/docs/source/release/0.3.0-notes.rst +++ b/cuda_core/docs/source/release/0.3.0-notes.rst @@ -22,7 +22,7 @@ New features - :class:`Kernel` adds :attr:`Kernel.num_arguments` and :attr:`Kernel.arguments_info` for introspection of kernel arguments. (#612) - Add pythonic access to kernel occupancy calculation functions via :attr:`Kernel.occupancy`. (#648) -- Support launching cooperative kernels by setting :property:`LaunchConfig.cooperative_launch` to `True`. +- Support launching cooperative kernels by setting :attr:`LaunchConfig.cooperative_launch` to `True`. - A name can be assigned to :class:`ObjectCode` instances generated by both :class:`Program` and :class:`Linker` through their respective options. @@ -34,5 +34,6 @@ New examples Fixes and enhancements ---------------------- -- An :class:`Event` can now be used to look up its corresponding device and context using the ``.device`` and ``.context`` attributes respectively. +- Look-up of the :attr:`Event.device` and :attr:`Event.context` (the device and CUDA context where an event was created from) is now possible. 
- The :func:`launch` function's handling of fp16 scalars was incorrect and is fixed. +- The :class:`Device` constructor is now faster.