From 247af8c1727dc35f9c572af1be335eeccd2241e3 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Mon, 24 Feb 2025 14:44:28 -0800 Subject: [PATCH 01/25] add IPC support to default async mempool --- cuda_bindings/cuda/bindings/driver.pyx.in | 2 +- cuda_core/cuda/core/experimental/__init__.py | 1 + cuda_core/cuda/core/experimental/_device.py | 4 +- cuda_core/cuda/core/experimental/_memory.py | 324 ++++++++++++++++++- cuda_core/cuda/core/experimental/_utils.py | 10 +- cuda_core/tests/test_memory.py | 307 +++++++++++++++++- 6 files changed, 621 insertions(+), 27 deletions(-) diff --git a/cuda_bindings/cuda/bindings/driver.pyx.in b/cuda_bindings/cuda/bindings/driver.pyx.in index 6be529571..9ef2c4162 100644 --- a/cuda_bindings/cuda/bindings/driver.pyx.in +++ b/cuda_bindings/cuda/bindings/driver.pyx.in @@ -34305,7 +34305,7 @@ def cuMemPrefetchAsync_v2(devPtr, size_t count, location not None : CUmemLocatio to the host NUMA node closest to the current thread's CPU by specifying :py:obj:`~.CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT` for :py:obj:`~.CUmemLocation.type`. Note when - :py:obj:`~.CUmemLocation.type` is etiher + :py:obj:`~.type` is etiher :py:obj:`~.CU_MEM_LOCATION_TYPE_HOST` OR :py:obj:`~.CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT`, :py:obj:`~.CUmemLocation.id` will be ignored. diff --git a/cuda_core/cuda/core/experimental/__init__.py b/cuda_core/cuda/core/experimental/__init__.py index 6e289d49b..6bd5acf5b 100644 --- a/cuda_core/cuda/core/experimental/__init__.py +++ b/cuda_core/cuda/core/experimental/__init__.py @@ -7,6 +7,7 @@ from cuda.core.experimental._event import EventOptions from cuda.core.experimental._launcher import LaunchConfig, launch from cuda.core.experimental._linker import Linker, LinkerOptions +from cuda.core.experimental._memory import AsyncMempool from cuda.core.experimental._module import ObjectCode from cuda.core.experimental._program import Program, ProgramOptions from cuda.core.experimental._stream import Stream, StreamOptions diff --git a/cuda_core/cuda/core/experimental/_device.py b/cuda_core/cuda/core/experimental/_device.py index 0cbd462cd..23868ee48 100644 --- a/cuda_core/cuda/core/experimental/_device.py +++ b/cuda_core/cuda/core/experimental/_device.py @@ -6,7 +6,7 @@ from typing import Union from cuda.core.experimental._context import Context, ContextOptions -from cuda.core.experimental._memory import Buffer, MemoryResource, _DefaultAsyncMempool, _SynchronousMemoryResource +from cuda.core.experimental._memory import AsyncMempool, Buffer, MemoryResource, _SynchronousMemoryResource from cuda.core.experimental._stream import Stream, StreamOptions, default_stream from cuda.core.experimental._utils import ComputeCapability, CUDAError, driver, handle_return, precondition, runtime @@ -962,7 +962,7 @@ def __new__(cls, device_id=None): runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0) ) ) == 1: - dev._mr = _DefaultAsyncMempool(dev_id) + dev._mr = AsyncMempool._from_device(dev_id) else: dev._mr = _SynchronousMemoryResource(dev_id) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index 403ee0842..40cf99cab 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -5,6 +5,7 @@ from __future__ import annotations import abc +import platform import weakref from typing import Optional, Tuple, TypeVar @@ -78,6 +79,11 @@ def close(self, stream=None): the default stream. """ + if self._mnff.mr is None: + raise RuntimeError( + "Cannot close a buffer that was not allocated from a memory resource, this buffer is: ", + self, + ) self._mnff.close(stream) @property @@ -204,8 +210,41 @@ def __release_buffer__(self, buffer: memoryview, /): raise NotImplementedError("TODO") +class IPCBuffer(Buffer): + """Buffer class to represent a buffer description which can be shared across processes. + It is not a valid buffer containing data, but rather a description used by the importing + process to construct a valid buffer. It's primary use is to provide a serialization + mechanism for passing exported buffers between processes.""" + + def __init__(self, reserved: bytes, size): + super().__init__(0, 0) + self.reserved = reserved + self._size = size + + def close(self): + raise NotImplementedError("Cannot close an IPC buffer directly") + + def copy_from(self, src, size=None): + raise NotImplementedError("Cannot copy to an IPC buffer") + + def copy_to(self, dst, size=None): + raise NotImplementedError("Cannot copy from an IPC buffer") + + def __reduce__(self): + # This is subject to change if the CumemPoolPtrExportData struct/object changes. + return (self._reconstruct, (self.reserved, self._size)) + + @classmethod + def _reconstruct(cls, reserved, size): + instance = cls(reserved, size) + return instance + + class MemoryResource(abc.ABC): - __slots__ = ("_handle",) + """Base class for memory resources. + + This class provides an abstract interface for memory resources. + """ @abc.abstractmethod def __init__(self, *args, **kwargs): ... @@ -238,36 +277,305 @@ def device_id(self) -> int: ... -class _DefaultAsyncMempool(MemoryResource): - __slots__ = ("_dev_id",) +def _get_platform_handle_type() -> int: + """Returns the appropriate handle type for the current platform.""" + system = platform.system() + if system == "Linux": + return driver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + elif system == "Windows": + return driver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_WIN32 + else: + raise RuntimeError(f"Unsupported platform: {system}") - def __init__(self, dev_id): - self._handle = handle_return(driver.cuDeviceGetMemPool(dev_id)) + +class AsyncMempool(MemoryResource): + """A CUDA memory pool for efficient memory allocation. + + This class creates a CUDA memory pool that provides better allocation and + deallocation performance compared to individual allocations. The pool can + optionally be configured to support sharing across process boundaries. + + Use the static methods create() or from_shared_handle() to instantiate. + Direct instantiation is not supported. + + Notes + ----- + The _from_device() method is for internal use by the Device class only and + should not be called directly by users. + """ + + class _MembersNeededForFinalize: + __slots__ = ("handle", "need_close") + + def __init__(self, mr_obj, handle, need_close): + self.handle = handle + self.need_close = need_close + weakref.finalize(mr_obj, self.close) + + def close(self): + if self.handle and self.need_close: + handle_return(driver.cuMemPoolDestroy(self.handle)) + self.handle = None + self.need_close = False + + __slots__ = ("_mnff", "_dev_id", "_ipc_enabled") + + def __init__(self): + """Direct instantiation is not supported. + + Use the static methods create() or from_shared_handle() instead. + """ + raise NotImplementedError( + "directly creating an AsyncMempool object is not supported. Please use either " + "AsyncMempool.create() or from_shared_handle()" + ) + + @staticmethod + def _init(dev_id: int, handle: int, ipc_enabled: bool = False, need_close: bool = False) -> AsyncMempool: + """Internal constructor for AsyncMempool objects. + + Parameters + ---------- + dev_id : int + The ID of the GPU device where the memory pool will be created + handle : int + The handle to the CUDA memory pool + ipc_enabled : bool + Whether the pool supports inter-process sharing capabilities + + Returns + ------- + AsyncMempool + A new memory pool instance + """ + self = AsyncMempool.__new__(AsyncMempool) self._dev_id = dev_id + self._ipc_enabled = ipc_enabled + self._mnff = AsyncMempool._MembersNeededForFinalize(self, handle, need_close) + return self - def allocate(self, size, stream=None) -> Buffer: + @staticmethod + def _from_device(dev_id: int) -> AsyncMempool: + """Internal method to create an AsyncMempool for a device's default memory pool. + + This method is intended for internal use by the Device class only. + Users should not call this method directly. + + Parameters + ---------- + dev_id : int + The ID of the GPU device to get the default memory pool from + + Returns + ------- + AsyncMempool + A memory pool instance connected to the device's default pool + """ + handle = handle_return(driver.cuDeviceGetMemPool(dev_id)) + return AsyncMempool._init(dev_id, handle, ipc_enabled=False, need_close=False) + + @staticmethod + def create(dev_id: int, max_size: int, ipc_enabled: bool = False) -> AsyncMempool: + """Create a new memory pool. + + Parameters + ---------- + dev_id : int + The ID of the GPU device where the memory pool will be created + max_size : int + Maximum size in bytes that the memory pool can grow to + ipc_enabled : bool, optional + Whether to enable inter-process sharing capabilities. Default is False. + + Returns + ------- + AsyncMempool + A new memory pool instance + + Raises + ------ + ValueError + If max_size is None + CUDAError + If pool creation fails + """ + if max_size is None: + raise ValueError("max_size must be provided when creating a new memory pool") + + properties = driver.CUmemPoolProps() + properties.allocType = driver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED + properties.handleTypes = ( + _get_platform_handle_type() if ipc_enabled else driver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_NONE + ) + properties.location = driver.CUmemLocation() + properties.location.id = dev_id + properties.location.type = driver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + properties.maxSize = max_size + properties.win32SecurityAttributes = 0 + properties.usage = 0 + + handle = handle_return(driver.cuMemPoolCreate(properties)) + return AsyncMempool._init(dev_id, handle, ipc_enabled=ipc_enabled, need_close=True) + + @staticmethod + def from_shared_handle(dev_id: int, shared_handle: int) -> AsyncMempool: + """Create an AsyncMempool from an existing handle. + + Parameters + ---------- + dev_id : int + The ID of the GPU device where the memory pool will be created + shared_handle : int + A platform-specific handle to import an existing memory pool + + Returns + ------- + AsyncMempool + A memory pool instance connected to the existing pool + """ + handle = handle_return(driver.cuMemPoolImportFromShareableHandle(shared_handle, _get_platform_handle_type(), 0)) + return AsyncMempool._init( + dev_id, handle, ipc_enabled=True, need_close=True + ) # Imported pools are always IPC-enabled + + def get_shareable_handle(self) -> int: + """Get a platform-specific handle that can be shared with other processes.""" + if not self._ipc_enabled: + raise RuntimeError("This memory pool was not created with IPC support enabled") + return handle_return(driver.cuMemPoolExportToShareableHandle(self._mnff.handle, _get_platform_handle_type(), 0)) + + def export_buffer(self, buffer: Buffer) -> IPCBuffer: + """Export a buffer allocated from this pool for sharing between processes.""" + if not self._ipc_enabled: + raise RuntimeError("This memory pool was not created with IPC support enabled") + return IPCBuffer(handle_return(driver.cuMemPoolExportPointer(buffer.handle)).reserved, buffer._mnff.size) + + def import_buffer(self, ipc_buffer: IPCBuffer) -> Buffer: + """Import a buffer that was exported from another process.""" + if not self._ipc_enabled: + raise RuntimeError("This memory pool was not created with IPC support enabled") + share_data = driver.CUmemPoolPtrExportData() + share_data.reserved = ipc_buffer.reserved + return Buffer( + handle_return(driver.cuMemPoolImportPointer(self._mnff.handle, share_data)), ipc_buffer._size, self + ) + + def allocate(self, size: int, stream=None) -> Buffer: + """Allocate memory from the pool.""" if stream is None: stream = default_stream() - ptr = handle_return(driver.cuMemAllocFromPoolAsync(size, self._handle, stream.handle)) + ptr = handle_return(driver.cuMemAllocFromPoolAsync(size, self._mnff.handle, stream.handle)) return Buffer(ptr, size, self) - def deallocate(self, ptr, size, stream=None): + def deallocate(self, ptr: int, size: int, stream=None) -> None: + """Deallocate memory back to the pool.""" if stream is None: stream = default_stream() handle_return(driver.cuMemFreeAsync(ptr, stream.handle)) @property def is_device_accessible(self) -> bool: + """Whether memory from this pool is accessible from device code.""" return True @property def is_host_accessible(self) -> bool: + """Whether memory from this pool is accessible from host code.""" return False @property def device_id(self) -> int: + """The ID of the GPU device this memory pool is associated with.""" return self._dev_id + @property + def reuse_follow_event_dependencies(self) -> bool: + """Allow memory to be reused when there are event dependencies between streams.""" + return bool( + handle_return( + driver.cuMemPoolGetAttribute( + self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES + ) + ) + ) + + @property + def reuse_allow_opportunistic(self) -> bool: + """Allow reuse of completed frees without dependencies.""" + return bool( + handle_return( + driver.cuMemPoolGetAttribute( + self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC + ) + ) + ) + + @property + def reuse_allow_internal_dependencies(self) -> bool: + """Allow insertion of new stream dependencies for memory reuse.""" + return bool( + handle_return( + driver.cuMemPoolGetAttribute( + self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES + ) + ) + ) + + @property + def release_threshold(self) -> int: + """Amount of reserved memory to hold before OS release.""" + return int( + handle_return( + driver.cuMemPoolGetAttribute( + self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_RELEASE_THRESHOLD + ) + ) + ) + + @property + def reserved_mem_current(self) -> int: + """Current amount of backing memory allocated.""" + return int( + handle_return( + driver.cuMemPoolGetAttribute( + self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT + ) + ) + ) + + @property + def reserved_mem_high(self) -> int: + """High watermark of backing memory allocated.""" + return int( + handle_return( + driver.cuMemPoolGetAttribute( + self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH + ) + ) + ) + + @property + def used_mem_current(self) -> int: + """Current amount of memory in use.""" + return int( + handle_return( + driver.cuMemPoolGetAttribute( + self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_USED_MEM_CURRENT + ) + ) + ) + + @property + def used_mem_high(self) -> int: + """High watermark of memory in use.""" + return int( + handle_return( + driver.cuMemPoolGetAttribute( + self._mnff.handle, driver.CUmemPool_attribute.CU_MEMPOOL_ATTR_USED_MEM_HIGH + ) + ) + ) + class _DefaultPinnedMemorySource(MemoryResource): def __init__(self): diff --git a/cuda_core/cuda/core/experimental/_utils.py b/cuda_core/cuda/core/experimental/_utils.py index 3538ae6c1..2a2889675 100644 --- a/cuda_core/cuda/core/experimental/_utils.py +++ b/cuda_core/cuda/core/experimental/_utils.py @@ -14,6 +14,7 @@ from cuda import cuda as driver from cuda import cudart as runtime from cuda import nvrtc +import traceback class CUDAError(Exception): @@ -35,7 +36,14 @@ def _check_error(error, handle=None): if err == driver.CUresult.CUDA_SUCCESS: err, desc = driver.cuGetErrorString(error) if err == driver.CUresult.CUDA_SUCCESS: - raise CUDAError(f"{name.decode()}: {desc.decode()}") + stack = traceback.extract_stack() + # Get the last 2 frames (excluding the current one) + relevant_stack = stack[-4:-1] + stack_info = "\n".join( + f" File '{frame.filename}', line {frame.lineno}, in {frame.name}\n {frame.line}" + for frame in relevant_stack + ) + raise CUDAError(f"{name.decode()}: {desc.decode()}\n{stack_info}") else: raise CUDAError(f"unknown error: {error}") elif isinstance(error, runtime.cudaError_t): diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index a48db69b5..50eb0d594 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -6,16 +6,23 @@ # this software and related documentation outside the terms of the EULA # is strictly prohibited. +import traceback + +import pytest + try: from cuda.bindings import driver except ImportError: from cuda import cuda as driver +import array import ctypes +import multiprocessing +from socket import AF_UNIX, CMSG_LEN, SCM_RIGHTS, SOCK_DGRAM, SOL_SOCKET, socketpair from cuda.core.experimental import Device -from cuda.core.experimental._memory import Buffer, MemoryResource -from cuda.core.experimental._utils import handle_return +from cuda.core.experimental._memory import AsyncMempool, Buffer, MemoryResource +from cuda.core.experimental._utils import get_binding_version, handle_return class DummyDeviceMemoryResource(MemoryResource): @@ -47,12 +54,10 @@ def __init__(self): pass def allocate(self, size, stream=None) -> Buffer: - # Allocate a ctypes buffer of size `size` ptr = (ctypes.c_byte * size)() return Buffer(ptr=ptr, size=size, mr=self) def deallocate(self, ptr, size, stream=None): - # the memory is deallocated per the ctypes deallocation at garbage collection time pass @property @@ -116,10 +121,11 @@ def device_id(self) -> int: raise RuntimeError("the pinned memory resource is not bound to any GPU") +# Basic Buffer Tests def buffer_initialization(dummy_mr: MemoryResource): - buffer = dummy_mr.allocate(size=1024) + buffer = dummy_mr.allocate(size=64) assert buffer.handle != 0 - assert buffer.size == 1024 + assert buffer.size == 64 assert buffer.memory_resource == dummy_mr assert buffer.is_device_accessible == dummy_mr.is_device_accessible assert buffer.is_host_accessible == dummy_mr.is_host_accessible @@ -136,13 +142,13 @@ def test_buffer_initialization(): def buffer_copy_to(dummy_mr: MemoryResource, device: Device, check=False): - src_buffer = dummy_mr.allocate(size=1024) - dst_buffer = dummy_mr.allocate(size=1024) + src_buffer = dummy_mr.allocate(size=64) + dst_buffer = dummy_mr.allocate(size=64) stream = device.create_stream() if check: src_ptr = ctypes.cast(src_buffer.handle, ctypes.POINTER(ctypes.c_byte)) - for i in range(1024): + for i in range(64): src_ptr[i] = ctypes.c_byte(i) src_buffer.copy_to(dst_buffer, stream=stream) @@ -150,7 +156,6 @@ def buffer_copy_to(dummy_mr: MemoryResource, device: Device, check=False): if check: dst_ptr = ctypes.cast(dst_buffer.handle, ctypes.POINTER(ctypes.c_byte)) - for i in range(10): assert dst_ptr[i] == src_ptr[i] @@ -167,13 +172,13 @@ def test_buffer_copy_to(): def buffer_copy_from(dummy_mr: MemoryResource, device, check=False): - src_buffer = dummy_mr.allocate(size=1024) - dst_buffer = dummy_mr.allocate(size=1024) + src_buffer = dummy_mr.allocate(size=64) + dst_buffer = dummy_mr.allocate(size=64) stream = device.create_stream() if check: src_ptr = ctypes.cast(src_buffer.handle, ctypes.POINTER(ctypes.c_byte)) - for i in range(1024): + for i in range(64): src_ptr[i] = ctypes.c_byte(i) dst_buffer.copy_from(src_buffer, stream=stream) @@ -181,7 +186,6 @@ def buffer_copy_from(dummy_mr: MemoryResource, device, check=False): if check: dst_ptr = ctypes.cast(dst_buffer.handle, ctypes.POINTER(ctypes.c_byte)) - for i in range(10): assert dst_ptr[i] == src_ptr[i] @@ -198,7 +202,7 @@ def test_buffer_copy_from(): def buffer_close(dummy_mr: MemoryResource): - buffer = dummy_mr.allocate(size=1024) + buffer = dummy_mr.allocate(size=64) buffer.close() assert buffer.handle == 0 assert buffer.memory_resource is None @@ -211,3 +215,276 @@ def test_buffer_close(): buffer_close(DummyHostMemoryResource()) buffer_close(DummyUnifiedMemoryResource(device)) buffer_close(DummyPinnedMemoryResource(device)) + + +def test_mempool(): + if get_binding_version() < (12, 0): + pytest.skip("Test requires CUDA 12 or higher") + device = Device() + device.set_current() + pool_size = 2097152 # 2MB size + + # Test basic pool creation + mr = AsyncMempool.create(device.device_id, pool_size, enable_ipc=False) + assert mr.device_id == device.device_id + assert mr.is_device_accessible + assert not mr.is_host_accessible + + # Test allocation and deallocation + buffer1 = mr.allocate(1024) + assert buffer1.handle != 0 + assert buffer1.size == 1024 + assert buffer1.memory_resource == mr + buffer1.close() + + # Test multiple allocations + buffer1 = mr.allocate(1024) + buffer2 = mr.allocate(2048) + assert buffer1.handle != buffer2.handle + assert buffer1.size == 1024 + assert buffer2.size == 2048 + buffer1.close() + buffer2.close() + + # Test stream-based allocation + stream = device.create_stream() + buffer = mr.allocate(1024, stream=stream) + assert buffer.handle != 0 + buffer.close() + + # Test memory copying between buffers from same pool + src_buffer = mr.allocate(64) + dst_buffer = mr.allocate(64) + stream = device.create_stream() + src_buffer.copy_to(dst_buffer, stream=stream) + device.sync() + dst_buffer.close() + src_buffer.close() + + # Test error cases + with pytest.raises(NotImplementedError, match="directly creating a AsyncMempool object is not supported"): + AsyncMempool() + + with pytest.raises(ValueError, match="max_size must be provided when creating a new memory pool"): + AsyncMempool.create(device.device_id, None) + + # Test IPC operations are disabled + buffer = mr.allocate(64) + + with pytest.raises(RuntimeError, match="This memory pool was not created with IPC support enabled"): + mr.get_shareable_handle() + + with pytest.raises(RuntimeError, match="This memory pool was not created with IPC support enabled"): + mr.export_buffer(buffer) + + with pytest.raises(RuntimeError, match="This memory pool was not created with IPC support enabled"): + mr.import_buffer(None) + + buffer.close() + + +@pytest.mark.parametrize( + "property_name,expected_type", + [ + ("reuse_follow_event_dependencies", bool), + ("reuse_allow_opportunistic", bool), + ("reuse_allow_internal_dependencies", bool), + ("release_threshold", int), + ("reserved_mem_current", int), + ("reserved_mem_high", int), + ("used_mem_current", int), + ("used_mem_high", int), + ], +) +def test_mempool_properties(property_name, expected_type): + """Test all properties of the AsyncMempool class.""" + # skip test if cuda version is less than 12 + if get_binding_version() < (12, 0): + pytest.skip("Test requires CUDA 12 or higher") + + device = Device() + device.set_current() + pool_size = 2097152 # 2MB size + mr = AsyncMempool.create(device.device_id, pool_size, enable_ipc=False) + + try: + # Get the property value + value = getattr(mr, property_name) + + # Test type + assert isinstance(value, expected_type), f"{property_name} should return {expected_type}, got {type(value)}" + + # Test value constraints + if expected_type is int: + assert value >= 0, f"{property_name} should be non-negative" + + # Test memory usage properties with actual allocations + if property_name in ["reserved_mem_current", "used_mem_current"]: + # Allocate some memory and check if values increase + initial_value = value + buffer = None + try: + buffer = mr.allocate(1024) + new_value = getattr(mr, property_name) + assert new_value >= initial_value, f"{property_name} should increase or stay same after allocation" + finally: + if buffer is not None: + buffer.close() + + # Test high watermark properties + if property_name in ["reserved_mem_high", "used_mem_high"]: + # High watermark should never be less than current + current_prop = property_name.replace("_high", "_current") + current_value = getattr(mr, current_prop) + assert value >= current_value, f"{property_name} should be >= {current_prop}" + + finally: + # Ensure we allocate and deallocate a small buffer to flush any pending operations + flush_buffer = mr.allocate(64) + flush_buffer.close() + + +def mempool_child_process(importer, queue): + try: + device = Device() + device.set_current() + stream = device.create_stream() + + # Receive the handle via socket + fds = array.array("i") + _, ancdata, _, _ = importer.recvmsg(0, CMSG_LEN(fds.itemsize)) + assert len(ancdata) == 1 + cmsg_level, cmsg_type, cmsg_data = ancdata[0] + assert cmsg_level == SOL_SOCKET and cmsg_type == SCM_RIGHTS + fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + shared_handle = int(fds[0]) + + mr = AsyncMempool.from_shared_handle(device.device_id, shared_handle) + ipc_buffer = queue.get() # Get exported buffer data + buffer = mr.import_buffer(ipc_buffer) + + # Create a new buffer to verify data using unified memory + unified_mr = DummyUnifiedMemoryResource(device) + verify_buffer = unified_mr.allocate(64) + + # Copy data from imported buffer to verify contents + verify_buffer.copy_from(buffer, stream=stream) + device.sync() + + # Verify the data matches expected pattern + verify_ptr = ctypes.cast(int(verify_buffer.handle), ctypes.POINTER(ctypes.c_byte)) + for i in range(64): + assert ctypes.c_byte(verify_ptr[i]).value == ctypes.c_byte(i).value, f"Data mismatch at index {i}" + + # Write new pattern to the buffer using unified memory + src_buffer = unified_mr.allocate(64) + src_ptr = ctypes.cast(int(src_buffer.handle), ctypes.POINTER(ctypes.c_byte)) + for i in range(64): + src_ptr[i] = ctypes.c_byte(255 - i) # Write inverted pattern + + # Copy new pattern to the IPC buffer + buffer.copy_from(src_buffer, stream=stream) + device.sync() + + verify_buffer.close() + src_buffer.close() + buffer.close() + + queue.put(True) + except Exception as e: + # Capture the full traceback + tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) + queue.put((e, tb_str)) + + +def test_ipc_mempool(): + if get_binding_version() < (12, 0): + pytest.skip("Test requires CUDA 12 or higher") + # Set multiprocessing start method before creating any multiprocessing objects + multiprocessing.set_start_method("spawn", force=True) + + device = Device() + device.set_current() + stream = device.create_stream() + pool_size = 2097152 # 2MB size + mr = AsyncMempool.create(device.device_id, pool_size, enable_ipc=True) + + # Create socket pair for handle transfer + exporter, importer = socketpair(AF_UNIX, SOCK_DGRAM) + queue = multiprocessing.Queue() + process = None + + try: + shareable_handle = mr.get_shareable_handle() + + # Allocate and export memory + buffer = mr.allocate(64) + + try: + # Fill buffer with test pattern using unified memory + unified_mr = DummyUnifiedMemoryResource(device) + src_buffer = unified_mr.allocate(64) + try: + src_ptr = ctypes.cast(int(src_buffer.handle), ctypes.POINTER(ctypes.c_byte)) + for i in range(64): + src_ptr[i] = ctypes.c_byte(i) + + buffer.copy_from(src_buffer, stream=stream) + device.sync() + finally: + src_buffer.close() + + # Export buffer for IPC + ipc_buffer = mr.export_buffer(buffer) + + # Start child process + process = multiprocessing.Process(target=mempool_child_process, args=(importer, queue)) + process.start() + + # Send handles to child process + exporter.sendmsg([], [(SOL_SOCKET, SCM_RIGHTS, array.array("i", [shareable_handle]))]) + queue.put(ipc_buffer) + + # Wait for child process + process.join(timeout=10) + assert process.exitcode == 0 + + # Check for exceptions + if not queue.empty(): + result = queue.get() + if isinstance(result, tuple): + exception, traceback_str = result + print("\nException in child process:") + print(traceback_str) + raise exception + assert result is True + + # Verify child process wrote the inverted pattern using unified memory + verify_buffer = unified_mr.allocate(64) + try: + verify_buffer.copy_from(buffer, stream=stream) + device.sync() + + verify_ptr = ctypes.cast(int(verify_buffer.handle), ctypes.POINTER(ctypes.c_byte)) + for i in range(64): + assert ( + ctypes.c_byte(verify_ptr[i]).value == ctypes.c_byte(255 - i).value + ), f"Child process data not reflected in parent at index {i}" + finally: + verify_buffer.close() + + finally: + buffer.close() + + finally: + # Clean up all resources + if process is not None and process.is_alive(): + process.terminate() + process.join(timeout=1) + queue.close() + queue.join_thread() # Ensure the queue's background thread is cleaned up + exporter.close() + importer.close() + # Flush any pending operations + flush_buffer = mr.allocate(64) + flush_buffer.close() From 4b521519438913c508a8b65f9480e9537d5539a6 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Mon, 24 Feb 2025 14:47:01 -0800 Subject: [PATCH 02/25] revert utils --- cuda_core/cuda/core/experimental/_utils.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_utils.py b/cuda_core/cuda/core/experimental/_utils.py index 2a2889675..3538ae6c1 100644 --- a/cuda_core/cuda/core/experimental/_utils.py +++ b/cuda_core/cuda/core/experimental/_utils.py @@ -14,7 +14,6 @@ from cuda import cuda as driver from cuda import cudart as runtime from cuda import nvrtc -import traceback class CUDAError(Exception): @@ -36,14 +35,7 @@ def _check_error(error, handle=None): if err == driver.CUresult.CUDA_SUCCESS: err, desc = driver.cuGetErrorString(error) if err == driver.CUresult.CUDA_SUCCESS: - stack = traceback.extract_stack() - # Get the last 2 frames (excluding the current one) - relevant_stack = stack[-4:-1] - stack_info = "\n".join( - f" File '{frame.filename}', line {frame.lineno}, in {frame.name}\n {frame.line}" - for frame in relevant_stack - ) - raise CUDAError(f"{name.decode()}: {desc.decode()}\n{stack_info}") + raise CUDAError(f"{name.decode()}: {desc.decode()}") else: raise CUDAError(f"unknown error: {error}") elif isinstance(error, runtime.cudaError_t): From 050e4d6b02d79ad428d4e2f7cd9d58fe0d1854c4 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Mon, 24 Feb 2025 14:47:21 -0800 Subject: [PATCH 03/25] revert utils --- cuda_core/cuda/core/experimental/_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/cuda_core/cuda/core/experimental/_utils.py b/cuda_core/cuda/core/experimental/_utils.py index 3538ae6c1..2a2889675 100644 --- a/cuda_core/cuda/core/experimental/_utils.py +++ b/cuda_core/cuda/core/experimental/_utils.py @@ -14,6 +14,7 @@ from cuda import cuda as driver from cuda import cudart as runtime from cuda import nvrtc +import traceback class CUDAError(Exception): @@ -35,7 +36,14 @@ def _check_error(error, handle=None): if err == driver.CUresult.CUDA_SUCCESS: err, desc = driver.cuGetErrorString(error) if err == driver.CUresult.CUDA_SUCCESS: - raise CUDAError(f"{name.decode()}: {desc.decode()}") + stack = traceback.extract_stack() + # Get the last 2 frames (excluding the current one) + relevant_stack = stack[-4:-1] + stack_info = "\n".join( + f" File '{frame.filename}', line {frame.lineno}, in {frame.name}\n {frame.line}" + for frame in relevant_stack + ) + raise CUDAError(f"{name.decode()}: {desc.decode()}\n{stack_info}") else: raise CUDAError(f"unknown error: {error}") elif isinstance(error, runtime.cudaError_t): From 1da30fa94d3517fb7d2c9e057e30bda6113f7841 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Mon, 24 Feb 2025 14:48:30 -0800 Subject: [PATCH 04/25] rever utils --- cuda_core/cuda/core/experimental/_utils.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_utils.py b/cuda_core/cuda/core/experimental/_utils.py index 2a2889675..3538ae6c1 100644 --- a/cuda_core/cuda/core/experimental/_utils.py +++ b/cuda_core/cuda/core/experimental/_utils.py @@ -14,7 +14,6 @@ from cuda import cuda as driver from cuda import cudart as runtime from cuda import nvrtc -import traceback class CUDAError(Exception): @@ -36,14 +35,7 @@ def _check_error(error, handle=None): if err == driver.CUresult.CUDA_SUCCESS: err, desc = driver.cuGetErrorString(error) if err == driver.CUresult.CUDA_SUCCESS: - stack = traceback.extract_stack() - # Get the last 2 frames (excluding the current one) - relevant_stack = stack[-4:-1] - stack_info = "\n".join( - f" File '{frame.filename}', line {frame.lineno}, in {frame.name}\n {frame.line}" - for frame in relevant_stack - ) - raise CUDAError(f"{name.decode()}: {desc.decode()}\n{stack_info}") + raise CUDAError(f"{name.decode()}: {desc.decode()}") else: raise CUDAError(f"unknown error: {error}") elif isinstance(error, runtime.cudaError_t): From ebef4388ba77e92a5f54ceaf0181a1b644d0719c Mon Sep 17 00:00:00 2001 From: ksimpson Date: Tue, 25 Feb 2025 09:26:36 -0800 Subject: [PATCH 05/25] fix typo --- cuda_core/cuda/core/experimental/_utils.py | 10 +++++++++- cuda_core/tests/test_memory.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_utils.py b/cuda_core/cuda/core/experimental/_utils.py index 3538ae6c1..2a2889675 100644 --- a/cuda_core/cuda/core/experimental/_utils.py +++ b/cuda_core/cuda/core/experimental/_utils.py @@ -14,6 +14,7 @@ from cuda import cuda as driver from cuda import cudart as runtime from cuda import nvrtc +import traceback class CUDAError(Exception): @@ -35,7 +36,14 @@ def _check_error(error, handle=None): if err == driver.CUresult.CUDA_SUCCESS: err, desc = driver.cuGetErrorString(error) if err == driver.CUresult.CUDA_SUCCESS: - raise CUDAError(f"{name.decode()}: {desc.decode()}") + stack = traceback.extract_stack() + # Get the last 2 frames (excluding the current one) + relevant_stack = stack[-4:-1] + stack_info = "\n".join( + f" File '{frame.filename}', line {frame.lineno}, in {frame.name}\n {frame.line}" + for frame in relevant_stack + ) + raise CUDAError(f"{name.decode()}: {desc.decode()}\n{stack_info}") else: raise CUDAError(f"unknown error: {error}") elif isinstance(error, runtime.cudaError_t): diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index 50eb0d594..2e95aab0f 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -407,7 +407,7 @@ def test_ipc_mempool(): device.set_current() stream = device.create_stream() pool_size = 2097152 # 2MB size - mr = AsyncMempool.create(device.device_id, pool_size, enable_ipc=True) + mr = AsyncMempool.create(device.device_id, pool_size, ipc_enabled=True) # Create socket pair for handle transfer exporter, importer = socketpair(AF_UNIX, SOCK_DGRAM) From 5023339cd751327cdf07342ce832006ef9fda978 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Tue, 25 Feb 2025 10:13:09 -0800 Subject: [PATCH 06/25] fix typo --- cuda_core/tests/test_memory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index 2e95aab0f..64b8ed9e2 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -225,7 +225,7 @@ def test_mempool(): pool_size = 2097152 # 2MB size # Test basic pool creation - mr = AsyncMempool.create(device.device_id, pool_size, enable_ipc=False) + mr = AsyncMempool.create(device.device_id, pool_size, ipc_enabled=False) assert mr.device_id == device.device_id assert mr.is_device_accessible assert not mr.is_host_accessible @@ -305,7 +305,7 @@ def test_mempool_properties(property_name, expected_type): device = Device() device.set_current() pool_size = 2097152 # 2MB size - mr = AsyncMempool.create(device.device_id, pool_size, enable_ipc=False) + mr = AsyncMempool.create(device.device_id, pool_size, ipc_enabled=False) try: # Get the property value From e4da6337537a844367303725f5e42933d929f810 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Tue, 25 Feb 2025 10:34:46 -0800 Subject: [PATCH 07/25] another typo --- cuda_core/tests/test_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index 64b8ed9e2..753d39a58 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -262,7 +262,7 @@ def test_mempool(): src_buffer.close() # Test error cases - with pytest.raises(NotImplementedError, match="directly creating a AsyncMempool object is not supported"): + with pytest.raises(NotImplementedError, match="directly creating an AsyncMempool object is not supported"): AsyncMempool() with pytest.raises(ValueError, match="max_size must be provided when creating a new memory pool"): From 77678e969f9b306f4bec112f5c8514bb3d5cac88 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Tue, 25 Feb 2025 13:40:00 -0800 Subject: [PATCH 08/25] make IPC buffer a descriptor not a buffer --- cuda_core/cuda/core/experimental/_memory.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index 40cf99cab..52ce6fe85 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -210,7 +210,7 @@ def __release_buffer__(self, buffer: memoryview, /): raise NotImplementedError("TODO") -class IPCBuffer(Buffer): +class IPCBufferDescriptor: """Buffer class to represent a buffer description which can be shared across processes. It is not a valid buffer containing data, but rather a description used by the importing process to construct a valid buffer. It's primary use is to provide a serialization @@ -221,15 +221,6 @@ def __init__(self, reserved: bytes, size): self.reserved = reserved self._size = size - def close(self): - raise NotImplementedError("Cannot close an IPC buffer directly") - - def copy_from(self, src, size=None): - raise NotImplementedError("Cannot copy to an IPC buffer") - - def copy_to(self, dst, size=None): - raise NotImplementedError("Cannot copy from an IPC buffer") - def __reduce__(self): # This is subject to change if the CumemPoolPtrExportData struct/object changes. return (self._reconstruct, (self.reserved, self._size)) @@ -444,13 +435,15 @@ def get_shareable_handle(self) -> int: raise RuntimeError("This memory pool was not created with IPC support enabled") return handle_return(driver.cuMemPoolExportToShareableHandle(self._mnff.handle, _get_platform_handle_type(), 0)) - def export_buffer(self, buffer: Buffer) -> IPCBuffer: + def export_buffer(self, buffer: Buffer) -> IPCBufferDescriptor: """Export a buffer allocated from this pool for sharing between processes.""" if not self._ipc_enabled: raise RuntimeError("This memory pool was not created with IPC support enabled") - return IPCBuffer(handle_return(driver.cuMemPoolExportPointer(buffer.handle)).reserved, buffer._mnff.size) + return IPCBufferDescriptor( + handle_return(driver.cuMemPoolExportPointer(buffer.handle)).reserved, buffer._mnff.size + ) - def import_buffer(self, ipc_buffer: IPCBuffer) -> Buffer: + def import_buffer(self, ipc_buffer: IPCBufferDescriptor) -> Buffer: """Import a buffer that was exported from another process.""" if not self._ipc_enabled: raise RuntimeError("This memory pool was not created with IPC support enabled") From 1711f7800cca1dd4f295f567c529f92cc9484651 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Tue, 25 Feb 2025 13:47:19 -0800 Subject: [PATCH 09/25] support windows --- cuda_core/tests/test_memory.py | 63 ++++++++++++++++++++++++---------- 1 file changed, 45 insertions(+), 18 deletions(-) diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index 753d39a58..b8fa03c1c 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -6,6 +6,7 @@ # this software and related documentation outside the terms of the EULA # is strictly prohibited. +import platform import traceback import pytest @@ -18,7 +19,11 @@ import array import ctypes import multiprocessing -from socket import AF_UNIX, CMSG_LEN, SCM_RIGHTS, SOCK_DGRAM, SOL_SOCKET, socketpair + +if platform.system() == "Linux": + from socket import AF_UNIX, CMSG_LEN, SCM_RIGHTS, SOCK_DGRAM, SOL_SOCKET, socketpair + + pass from cuda.core.experimental import Device from cuda.core.experimental._memory import AsyncMempool, Buffer, MemoryResource @@ -350,14 +355,18 @@ def mempool_child_process(importer, queue): device.set_current() stream = device.create_stream() - # Receive the handle via socket - fds = array.array("i") - _, ancdata, _, _ = importer.recvmsg(0, CMSG_LEN(fds.itemsize)) - assert len(ancdata) == 1 - cmsg_level, cmsg_type, cmsg_data = ancdata[0] - assert cmsg_level == SOL_SOCKET and cmsg_type == SCM_RIGHTS - fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) - shared_handle = int(fds[0]) + # Get the shared handle differently based on platform + if platform.system() == "Windows": + shared_handle = queue.get() # On Windows, we pass the handle through the queue + else: + # Unix socket handle transfer + fds = array.array("i") + _, ancdata, _, _ = importer.recvmsg(0, CMSG_LEN(fds.itemsize)) + assert len(ancdata) == 1 + cmsg_level, cmsg_type, cmsg_data = ancdata[0] + assert cmsg_level == SOL_SOCKET and cmsg_type == SCM_RIGHTS + fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + shared_handle = int(fds[0]) mr = AsyncMempool.from_shared_handle(device.device_id, shared_handle) ipc_buffer = queue.get() # Get exported buffer data @@ -400,17 +409,26 @@ def mempool_child_process(importer, queue): def test_ipc_mempool(): if get_binding_version() < (12, 0): pytest.skip("Test requires CUDA 12 or higher") - # Set multiprocessing start method before creating any multiprocessing objects - multiprocessing.set_start_method("spawn", force=True) + # Check if IPC is supported on this platform/device device = Device() device.set_current() + if not device.properties.memory_pools_supported: + pytest.skip("Device does not support mempool operations") + + # Set multiprocessing start method before creating any multiprocessing objects + multiprocessing.set_start_method("spawn", force=True) + stream = device.create_stream() pool_size = 2097152 # 2MB size mr = AsyncMempool.create(device.device_id, pool_size, ipc_enabled=True) - # Create socket pair for handle transfer - exporter, importer = socketpair(AF_UNIX, SOCK_DGRAM) + # Create socket pair for handle transfer (only on Unix systems) + exporter = None + importer = None + if platform.system() == "Linux": + exporter, importer = socketpair(AF_UNIX, SOCK_DGRAM) + queue = multiprocessing.Queue() process = None @@ -438,11 +456,18 @@ def test_ipc_mempool(): ipc_buffer = mr.export_buffer(buffer) # Start child process - process = multiprocessing.Process(target=mempool_child_process, args=(importer, queue)) + process = multiprocessing.Process( + target=mempool_child_process, args=(importer if platform.system() == "Linux" else None, queue) + ) process.start() # Send handles to child process - exporter.sendmsg([], [(SOL_SOCKET, SCM_RIGHTS, array.array("i", [shareable_handle]))]) + if platform.system() == "Windows": + queue.put(shareable_handle) # Send handle through queue on Windows + else: + # Use Unix socket for handle transfer + exporter.sendmsg([], [(SOL_SOCKET, SCM_RIGHTS, array.array("i", [shareable_handle]))]) + queue.put(ipc_buffer) # Wait for child process @@ -482,9 +507,11 @@ def test_ipc_mempool(): process.terminate() process.join(timeout=1) queue.close() - queue.join_thread() # Ensure the queue's background thread is cleaned up - exporter.close() - importer.close() + queue.join_thread() + if exporter is not None: + exporter.close() + if importer is not None: + importer.close() # Flush any pending operations flush_buffer = mr.allocate(64) flush_buffer.close() From 4f25290b3ae444754548e5c3e8780562c9c9b478 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 26 Feb 2025 11:28:01 -0800 Subject: [PATCH 10/25] remove super call --- cuda_core/cuda/core/experimental/_memory.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index 52ce6fe85..dd41158e4 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -216,8 +216,7 @@ class IPCBufferDescriptor: process to construct a valid buffer. It's primary use is to provide a serialization mechanism for passing exported buffers between processes.""" - def __init__(self, reserved: bytes, size): - super().__init__(0, 0) + def __init__(self, reserved: bytes, size: int): self.reserved = reserved self._size = size From 24a5652108defe371ab30ddc73d12eaa131282b5 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 26 Feb 2025 11:57:24 -0800 Subject: [PATCH 11/25] add security handle with defualt no security for now --- cuda_core/cuda/core/experimental/_memory.py | 56 ++++++++++++++++++++- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index dd41158e4..4ae3289a8 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -5,6 +5,9 @@ from __future__ import annotations import abc + +# Add ctypes import for Windows security attributes +import ctypes import platform import weakref from typing import Optional, Tuple, TypeVar @@ -278,6 +281,41 @@ def _get_platform_handle_type() -> int: raise RuntimeError(f"Unsupported platform: {system}") +def _create_win32_security_attributes(): + """Creates a Windows SECURITY_ATTRIBUTES structure with default settings. + + Returns: + A pointer to a SECURITY_ATTRIBUTES structure or None if not on Windows. + """ + if platform.system() != "Windows": + return None + + # Define the Windows SECURITY_ATTRIBUTES structure + class SECURITY_ATTRIBUTES(ctypes.Structure): + _fields_ = [ + ("nLength", ctypes.c_ulong), + ("lpSecurityDescriptor", ctypes.c_void_p), + ("bInheritHandle", ctypes.c_int), + ] + + # Create a new security descriptor + security_descriptor = ctypes.windll.advapi32.LocalAlloc(0, 0) + + # Initialize the security descriptor (empty one with no security) + if not ctypes.windll.advapi32.InitializeSecurityDescriptor(security_descriptor, 1): + ctypes.windll.kernel32.LocalFree(security_descriptor) + return None + + # Create and initialize the security attributes structure + sa = SECURITY_ATTRIBUTES() + sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) + sa.lpSecurityDescriptor = security_descriptor + sa.bInheritHandle = 0 # Don't inherit handle + + # Return a pointer that can be passed to the CUDA API + return ctypes.addressof(sa) + + class AsyncMempool(MemoryResource): """A CUDA memory pool for efficient memory allocation. @@ -365,7 +403,9 @@ def _from_device(dev_id: int) -> AsyncMempool: return AsyncMempool._init(dev_id, handle, ipc_enabled=False, need_close=False) @staticmethod - def create(dev_id: int, max_size: int, ipc_enabled: bool = False) -> AsyncMempool: + def create( + dev_id: int, max_size: int, ipc_enabled: bool = False, win32_security_attributes: int = 0 + ) -> AsyncMempool: """Create a new memory pool. Parameters @@ -376,6 +416,9 @@ def create(dev_id: int, max_size: int, ipc_enabled: bool = False) -> AsyncMempoo Maximum size in bytes that the memory pool can grow to ipc_enabled : bool, optional Whether to enable inter-process sharing capabilities. Default is False. + win32_security_attributes : int, optional + Custom Windows security attributes pointer. If 0 (default), a default security + attributes structure will be created when needed on Windows platforms. Returns ------- @@ -401,7 +444,16 @@ def create(dev_id: int, max_size: int, ipc_enabled: bool = False) -> AsyncMempoo properties.location.id = dev_id properties.location.type = driver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE properties.maxSize = max_size - properties.win32SecurityAttributes = 0 + + # Set up Windows security attributes if needed + if platform.system() == "Windows" and ipc_enabled: + if win32_security_attributes == 0: + # Create default security attributes if none provided + win32_security_attributes = _create_win32_security_attributes() + properties.win32SecurityAttributes = win32_security_attributes + else: + properties.win32SecurityAttributes = 0 + properties.usage = 0 handle = handle_return(driver.cuMemPoolCreate(properties)) From da08473477b353913608bbebccc1d5a7fd81a7cf Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 26 Feb 2025 12:07:04 -0800 Subject: [PATCH 12/25] push for local test --- cuda_core/cuda/core/experimental/_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index 4ae3289a8..7292e66ad 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -308,7 +308,7 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): # Create and initialize the security attributes structure sa = SECURITY_ATTRIBUTES() - sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) + sa.nLength = 1 # ctypes.sizeof(SECURITY_ATTRIBUTES) sa.lpSecurityDescriptor = security_descriptor sa.bInheritHandle = 0 # Don't inherit handle From 9e1d546b6ccb2c66dac69a9f2cdcdb86266f6270 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 26 Feb 2025 12:15:56 -0800 Subject: [PATCH 13/25] switch dll call --- cuda_core/cuda/core/experimental/_memory.py | 54 ++++++++++++++++++--- 1 file changed, 48 insertions(+), 6 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index 7292e66ad..7530f4f6c 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -6,6 +6,9 @@ import abc +# Register cleanup function to be called at interpreter shutdown +import atexit + # Add ctypes import for Windows security attributes import ctypes import platform @@ -284,6 +287,9 @@ def _get_platform_handle_type() -> int: def _create_win32_security_attributes(): """Creates a Windows SECURITY_ATTRIBUTES structure with default settings. + The security descriptor is configured with a DACL that allows access to everyone, + which is appropriate for shared memory that needs to be accessible across processes. + Returns: A pointer to a SECURITY_ATTRIBUTES structure or None if not on Windows. """ @@ -298,24 +304,60 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): ("bInheritHandle", ctypes.c_int), ] - # Create a new security descriptor - security_descriptor = ctypes.windll.advapi32.LocalAlloc(0, 0) + # Constants for security descriptor creation + SECURITY_DESCRIPTOR_REVISION = 1 + SECURITY_DESCRIPTOR_MIN_LENGTH = 1024 - # Initialize the security descriptor (empty one with no security) - if not ctypes.windll.advapi32.InitializeSecurityDescriptor(security_descriptor, 1): + # Create a new security descriptor - use kernel32 for memory allocation + # LPTR = 0x0040 (LMEM_ZEROINIT | LMEM_FIXED) + LPTR = 0x0040 + security_descriptor = ctypes.windll.kernel32.LocalAlloc(LPTR, SECURITY_DESCRIPTOR_MIN_LENGTH) + + if not security_descriptor: + return None + + # Initialize the security descriptor + if not ctypes.windll.advapi32.InitializeSecurityDescriptor(security_descriptor, SECURITY_DESCRIPTOR_REVISION): + ctypes.windll.kernel32.LocalFree(security_descriptor) + return None + + # Set a NULL DACL which allows all access to everyone + # 3rd parameter is a BOOL that specifies whether to set a DACL (TRUE) or not (FALSE) + # 4th parameter is the DACL pointer (NULL for unrestricted access) + # 5th parameter is a BOOL that specifies whether the DACL was explicitly provided (TRUE) or defaulted (FALSE) + if not ctypes.windll.advapi32.SetSecurityDescriptorDacl(security_descriptor, True, None, False): ctypes.windll.kernel32.LocalFree(security_descriptor) return None # Create and initialize the security attributes structure sa = SECURITY_ATTRIBUTES() - sa.nLength = 1 # ctypes.sizeof(SECURITY_ATTRIBUTES) + sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) sa.lpSecurityDescriptor = security_descriptor - sa.bInheritHandle = 0 # Don't inherit handle + sa.bInheritHandle = False # Don't inherit handle + + # Store the security descriptor in a global variable to prevent it from being garbage collected + # and to allow cleanup when the module is unloaded + if not hasattr(_create_win32_security_attributes, "_security_descriptors"): + _create_win32_security_attributes._security_descriptors = [] + _create_win32_security_attributes._security_descriptors.append(security_descriptor) # Return a pointer that can be passed to the CUDA API return ctypes.addressof(sa) +# Add cleanup function for security descriptors +def _cleanup_security_descriptors(): + """Free any allocated security descriptors when the module is unloaded.""" + if hasattr(_create_win32_security_attributes, "_security_descriptors"): + for sd in _create_win32_security_attributes._security_descriptors: + if sd: + ctypes.windll.kernel32.LocalFree(sd) + _create_win32_security_attributes._security_descriptors.clear() + + +atexit.register(_cleanup_security_descriptors) + + class AsyncMempool(MemoryResource): """A CUDA memory pool for efficient memory allocation. From 7deb681fe6f3b43de5f3b70ebb17d9b181b77c88 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 26 Feb 2025 12:21:27 -0800 Subject: [PATCH 14/25] use library --- cuda_core/cuda/core/experimental/_memory.py | 85 ++++++++++++--------- 1 file changed, 48 insertions(+), 37 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index 7530f4f6c..73ea689fc 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -19,6 +19,21 @@ from cuda.core.experimental._stream import default_stream from cuda.core.experimental._utils import driver, handle_return +# Check if pywin32 is available on Windows +_PYWIN32_AVAILABLE = False +if platform.system() == "Windows": + try: + import win32security + + _PYWIN32_AVAILABLE = True + except ImportError: + import warnings + + warnings.warn( + "pywin32 module not found. For better IPC support on Windows, " "install it with: pip install pywin32", + stacklevel=2, + ) + PyCapsule = TypeVar("PyCapsule") @@ -287,8 +302,8 @@ def _get_platform_handle_type() -> int: def _create_win32_security_attributes(): """Creates a Windows SECURITY_ATTRIBUTES structure with default settings. - The security descriptor is configured with a DACL that allows access to everyone, - which is appropriate for shared memory that needs to be accessible across processes. + The security descriptor is configured to allow access across processes, + which is appropriate for shared memory. Returns: A pointer to a SECURITY_ATTRIBUTES structure or None if not on Windows. @@ -304,54 +319,50 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): ("bInheritHandle", ctypes.c_int), ] - # Constants for security descriptor creation - SECURITY_DESCRIPTOR_REVISION = 1 - SECURITY_DESCRIPTOR_MIN_LENGTH = 1024 + if _PYWIN32_AVAILABLE: + # Create a security descriptor using pywin32 + sd = win32security.SECURITY_DESCRIPTOR() - # Create a new security descriptor - use kernel32 for memory allocation - # LPTR = 0x0040 (LMEM_ZEROINIT | LMEM_FIXED) - LPTR = 0x0040 - security_descriptor = ctypes.windll.kernel32.LocalAlloc(LPTR, SECURITY_DESCRIPTOR_MIN_LENGTH) + # Create a blank DACL (this allows all access) + dacl = win32security.ACL() - if not security_descriptor: - return None + # Set the DACL to the security descriptor + sd.SetSecurityDescriptorDacl(1, dacl, 0) - # Initialize the security descriptor - if not ctypes.windll.advapi32.InitializeSecurityDescriptor(security_descriptor, SECURITY_DESCRIPTOR_REVISION): - ctypes.windll.kernel32.LocalFree(security_descriptor) - return None + # Create and initialize the security attributes structure + sa = SECURITY_ATTRIBUTES() + sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) + sa.lpSecurityDescriptor = ctypes.c_void_p(int(sd.SECURITY_DESCRIPTOR)) + sa.bInheritHandle = False - # Set a NULL DACL which allows all access to everyone - # 3rd parameter is a BOOL that specifies whether to set a DACL (TRUE) or not (FALSE) - # 4th parameter is the DACL pointer (NULL for unrestricted access) - # 5th parameter is a BOOL that specifies whether the DACL was explicitly provided (TRUE) or defaulted (FALSE) - if not ctypes.windll.advapi32.SetSecurityDescriptorDacl(security_descriptor, True, None, False): - ctypes.windll.kernel32.LocalFree(security_descriptor) - return None + # Store the security descriptor to prevent garbage collection + if not hasattr(_create_win32_security_attributes, "_security_descriptors"): + _create_win32_security_attributes._security_descriptors = [] + _create_win32_security_attributes._security_descriptors.append(sd) - # Create and initialize the security attributes structure - sa = SECURITY_ATTRIBUTES() - sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) - sa.lpSecurityDescriptor = security_descriptor - sa.bInheritHandle = False # Don't inherit handle + return ctypes.addressof(sa) + else: + # If pywin32 is not available, use a NULL security descriptor + # This is less secure but should work for testing + try: + sa = SECURITY_ATTRIBUTES() + sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) + sa.lpSecurityDescriptor = ctypes.c_void_p(0) # NULL security descriptor + sa.bInheritHandle = False - # Store the security descriptor in a global variable to prevent it from being garbage collected - # and to allow cleanup when the module is unloaded - if not hasattr(_create_win32_security_attributes, "_security_descriptors"): - _create_win32_security_attributes._security_descriptors = [] - _create_win32_security_attributes._security_descriptors.append(security_descriptor) + return ctypes.addressof(sa) - # Return a pointer that can be passed to the CUDA API - return ctypes.addressof(sa) + except Exception as e: + print(f"Warning: Failed to create security attributes: {e}") + return 0 # Return 0 as a fallback # Add cleanup function for security descriptors def _cleanup_security_descriptors(): """Free any allocated security descriptors when the module is unloaded.""" if hasattr(_create_win32_security_attributes, "_security_descriptors"): - for sd in _create_win32_security_attributes._security_descriptors: - if sd: - ctypes.windll.kernel32.LocalFree(sd) + # The security descriptors are now pywin32 objects that will be garbage collected + # or simple ctypes structures, so we just need to clear the list _create_win32_security_attributes._security_descriptors.clear() From c4e59f648a7215b6d59b8a49f581bca08a561c5f Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 26 Feb 2025 12:26:04 -0800 Subject: [PATCH 15/25] use library --- cuda_core/cuda/core/experimental/_memory.py | 119 ++++++++++++++++---- 1 file changed, 96 insertions(+), 23 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index 73ea689fc..80d0c9f0c 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -319,42 +319,108 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): ("bInheritHandle", ctypes.c_int), ] + # Try the pywin32 approach first if _PYWIN32_AVAILABLE: - # Create a security descriptor using pywin32 - sd = win32security.SECURITY_DESCRIPTOR() + try: + # Create a security descriptor using pywin32 + sd = win32security.SECURITY_DESCRIPTOR() + + # Create a blank DACL (this allows all access) + dacl = win32security.ACL() + + # Set the DACL to the security descriptor + sd.SetSecurityDescriptorDacl(1, dacl, 0) + + # Get the pointer to the security descriptor + # Different versions of pywin32 may have different ways to access the pointer + sd_pointer = None + + # Try different methods to get the pointer + try: + # Method 1: GetSecurityDescriptorSelf + sd_pointer = sd.GetSecurityDescriptorSelf() + except (AttributeError, TypeError): + try: + # Method 2: Direct attribute access (older versions) + sd_pointer = int(sd.SECURITY_DESCRIPTOR) + except (AttributeError, TypeError): + try: + # Method 3: Convert to int (some versions) + sd_pointer = int(sd) + except (TypeError, ValueError): + # Method 4: Last resort - use the handle if it's a PyHANDLE + if hasattr(sd, "handle"): + sd_pointer = sd.handle + + # If we couldn't get a pointer, fall back to NULL security descriptor + if sd_pointer is None: + raise ValueError("Could not get security descriptor pointer") + + # Create and initialize the security attributes structure + sa = SECURITY_ATTRIBUTES() + sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) + sa.lpSecurityDescriptor = ctypes.c_void_p(sd_pointer) + sa.bInheritHandle = False - # Create a blank DACL (this allows all access) - dacl = win32security.ACL() + # Store the security descriptor to prevent garbage collection + if not hasattr(_create_win32_security_attributes, "_security_descriptors"): + _create_win32_security_attributes._security_descriptors = [] + _create_win32_security_attributes._security_descriptors.append(sd) - # Set the DACL to the security descriptor - sd.SetSecurityDescriptorDacl(1, dacl, 0) + return ctypes.addressof(sa) + except Exception as e: + print(f"Warning: Failed to create security descriptor with pywin32: {e}") + # Fall through to the next method + + # Try direct ctypes approach if pywin32 failed or is not available + try: + # Constants for security descriptor creation + SECURITY_DESCRIPTOR_REVISION = 1 + SECURITY_DESCRIPTOR_MIN_LENGTH = 1024 + LPTR = 0x0040 # LMEM_ZEROINIT | LMEM_FIXED + + # Allocate memory for the security descriptor + security_descriptor = ctypes.windll.kernel32.LocalAlloc(LPTR, SECURITY_DESCRIPTOR_MIN_LENGTH) + if not security_descriptor: + raise OSError("Failed to allocate memory for security descriptor") + + # Initialize the security descriptor + if not ctypes.windll.advapi32.InitializeSecurityDescriptor(security_descriptor, SECURITY_DESCRIPTOR_REVISION): + ctypes.windll.kernel32.LocalFree(security_descriptor) + raise OSError("Failed to initialize security descriptor") + + # Set a NULL DACL which allows all access + if not ctypes.windll.advapi32.SetSecurityDescriptorDacl(security_descriptor, True, None, False): + ctypes.windll.kernel32.LocalFree(security_descriptor) + raise OSError("Failed to set security descriptor DACL") # Create and initialize the security attributes structure sa = SECURITY_ATTRIBUTES() sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) - sa.lpSecurityDescriptor = ctypes.c_void_p(int(sd.SECURITY_DESCRIPTOR)) + sa.lpSecurityDescriptor = ctypes.c_void_p(security_descriptor) sa.bInheritHandle = False - # Store the security descriptor to prevent garbage collection - if not hasattr(_create_win32_security_attributes, "_security_descriptors"): - _create_win32_security_attributes._security_descriptors = [] - _create_win32_security_attributes._security_descriptors.append(sd) + # Store the security descriptor for cleanup + if not hasattr(_create_win32_security_attributes, "_security_descriptors_ctypes"): + _create_win32_security_attributes._security_descriptors_ctypes = [] + _create_win32_security_attributes._security_descriptors_ctypes.append(security_descriptor) return ctypes.addressof(sa) - else: - # If pywin32 is not available, use a NULL security descriptor - # This is less secure but should work for testing - try: - sa = SECURITY_ATTRIBUTES() - sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) - sa.lpSecurityDescriptor = ctypes.c_void_p(0) # NULL security descriptor - sa.bInheritHandle = False + except Exception as e: + print(f"Warning: Failed to create security descriptor with ctypes: {e}") + # Fall through to the NULL security descriptor approach - return ctypes.addressof(sa) + # Last resort: NULL security descriptor + try: + sa = SECURITY_ATTRIBUTES() + sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) + sa.lpSecurityDescriptor = ctypes.c_void_p(0) # NULL security descriptor + sa.bInheritHandle = False - except Exception as e: - print(f"Warning: Failed to create security attributes: {e}") - return 0 # Return 0 as a fallback + return ctypes.addressof(sa) + except Exception as e: + print(f"Warning: Failed to create security attributes: {e}") + return 0 # Return 0 as a fallback # Add cleanup function for security descriptors @@ -365,6 +431,13 @@ def _cleanup_security_descriptors(): # or simple ctypes structures, so we just need to clear the list _create_win32_security_attributes._security_descriptors.clear() + # Clean up any ctypes security descriptors + if hasattr(_create_win32_security_attributes, "_security_descriptors_ctypes"): + for sd in _create_win32_security_attributes._security_descriptors_ctypes: + if sd: + ctypes.windll.kernel32.LocalFree(sd) + _create_win32_security_attributes._security_descriptors_ctypes.clear() + atexit.register(_cleanup_security_descriptors) From 22946327b0545d15d0d9f2ecbb01e6a94c8f1cdd Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 26 Feb 2025 12:30:28 -0800 Subject: [PATCH 16/25] use library --- cuda_core/cuda/core/experimental/_memory.py | 158 ++++---------------- 1 file changed, 32 insertions(+), 126 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index 80d0c9f0c..62ab2b7ca 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -15,25 +15,14 @@ import weakref from typing import Optional, Tuple, TypeVar +# Import win32security directly on Windows +if platform.system() == "Windows": + import win32security + from cuda.core.experimental._dlpack import DLDeviceType, make_py_capsule from cuda.core.experimental._stream import default_stream from cuda.core.experimental._utils import driver, handle_return -# Check if pywin32 is available on Windows -_PYWIN32_AVAILABLE = False -if platform.system() == "Windows": - try: - import win32security - - _PYWIN32_AVAILABLE = True - except ImportError: - import warnings - - warnings.warn( - "pywin32 module not found. For better IPC support on Windows, " "install it with: pip install pywin32", - stacklevel=2, - ) - PyCapsule = TypeVar("PyCapsule") @@ -319,125 +308,39 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): ("bInheritHandle", ctypes.c_int), ] - # Try the pywin32 approach first - if _PYWIN32_AVAILABLE: - try: - # Create a security descriptor using pywin32 - sd = win32security.SECURITY_DESCRIPTOR() - - # Create a blank DACL (this allows all access) - dacl = win32security.ACL() - - # Set the DACL to the security descriptor - sd.SetSecurityDescriptorDacl(1, dacl, 0) - - # Get the pointer to the security descriptor - # Different versions of pywin32 may have different ways to access the pointer - sd_pointer = None - - # Try different methods to get the pointer - try: - # Method 1: GetSecurityDescriptorSelf - sd_pointer = sd.GetSecurityDescriptorSelf() - except (AttributeError, TypeError): - try: - # Method 2: Direct attribute access (older versions) - sd_pointer = int(sd.SECURITY_DESCRIPTOR) - except (AttributeError, TypeError): - try: - # Method 3: Convert to int (some versions) - sd_pointer = int(sd) - except (TypeError, ValueError): - # Method 4: Last resort - use the handle if it's a PyHANDLE - if hasattr(sd, "handle"): - sd_pointer = sd.handle - - # If we couldn't get a pointer, fall back to NULL security descriptor - if sd_pointer is None: - raise ValueError("Could not get security descriptor pointer") - - # Create and initialize the security attributes structure - sa = SECURITY_ATTRIBUTES() - sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) - sa.lpSecurityDescriptor = ctypes.c_void_p(sd_pointer) - sa.bInheritHandle = False - - # Store the security descriptor to prevent garbage collection - if not hasattr(_create_win32_security_attributes, "_security_descriptors"): - _create_win32_security_attributes._security_descriptors = [] - _create_win32_security_attributes._security_descriptors.append(sd) - - return ctypes.addressof(sa) - except Exception as e: - print(f"Warning: Failed to create security descriptor with pywin32: {e}") - # Fall through to the next method - - # Try direct ctypes approach if pywin32 failed or is not available - try: - # Constants for security descriptor creation - SECURITY_DESCRIPTOR_REVISION = 1 - SECURITY_DESCRIPTOR_MIN_LENGTH = 1024 - LPTR = 0x0040 # LMEM_ZEROINIT | LMEM_FIXED - - # Allocate memory for the security descriptor - security_descriptor = ctypes.windll.kernel32.LocalAlloc(LPTR, SECURITY_DESCRIPTOR_MIN_LENGTH) - if not security_descriptor: - raise OSError("Failed to allocate memory for security descriptor") - - # Initialize the security descriptor - if not ctypes.windll.advapi32.InitializeSecurityDescriptor(security_descriptor, SECURITY_DESCRIPTOR_REVISION): - ctypes.windll.kernel32.LocalFree(security_descriptor) - raise OSError("Failed to initialize security descriptor") - - # Set a NULL DACL which allows all access - if not ctypes.windll.advapi32.SetSecurityDescriptorDacl(security_descriptor, True, None, False): - ctypes.windll.kernel32.LocalFree(security_descriptor) - raise OSError("Failed to set security descriptor DACL") - - # Create and initialize the security attributes structure - sa = SECURITY_ATTRIBUTES() - sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) - sa.lpSecurityDescriptor = ctypes.c_void_p(security_descriptor) - sa.bInheritHandle = False - - # Store the security descriptor for cleanup - if not hasattr(_create_win32_security_attributes, "_security_descriptors_ctypes"): - _create_win32_security_attributes._security_descriptors_ctypes = [] - _create_win32_security_attributes._security_descriptors_ctypes.append(security_descriptor) - - return ctypes.addressof(sa) - except Exception as e: - print(f"Warning: Failed to create security descriptor with ctypes: {e}") - # Fall through to the NULL security descriptor approach - - # Last resort: NULL security descriptor - try: - sa = SECURITY_ATTRIBUTES() - sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) - sa.lpSecurityDescriptor = ctypes.c_void_p(0) # NULL security descriptor - sa.bInheritHandle = False - - return ctypes.addressof(sa) - except Exception as e: - print(f"Warning: Failed to create security attributes: {e}") - return 0 # Return 0 as a fallback + # Create a security descriptor using pywin32 + sd = win32security.SECURITY_DESCRIPTOR() + + # Create a blank DACL (this allows all access) + dacl = win32security.ACL() + + # Set the DACL to the security descriptor + sd.SetSecurityDescriptorDacl(1, dacl, 0) + + # Get the pointer to the security descriptor + sd_pointer = sd.GetSecurityDescriptorSelf() + + # Create and initialize the security attributes structure + sa = SECURITY_ATTRIBUTES() + sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) + sa.lpSecurityDescriptor = ctypes.c_void_p(sd_pointer) + sa.bInheritHandle = False + + # Store the security descriptor to prevent garbage collection + if not hasattr(_create_win32_security_attributes, "_security_descriptors"): + _create_win32_security_attributes._security_descriptors = [] + _create_win32_security_attributes._security_descriptors.append(sd) + + return ctypes.addressof(sa) # Add cleanup function for security descriptors def _cleanup_security_descriptors(): """Free any allocated security descriptors when the module is unloaded.""" if hasattr(_create_win32_security_attributes, "_security_descriptors"): - # The security descriptors are now pywin32 objects that will be garbage collected - # or simple ctypes structures, so we just need to clear the list + # The security descriptors are pywin32 objects that will be garbage collected _create_win32_security_attributes._security_descriptors.clear() - # Clean up any ctypes security descriptors - if hasattr(_create_win32_security_attributes, "_security_descriptors_ctypes"): - for sd in _create_win32_security_attributes._security_descriptors_ctypes: - if sd: - ctypes.windll.kernel32.LocalFree(sd) - _create_win32_security_attributes._security_descriptors_ctypes.clear() - atexit.register(_cleanup_security_descriptors) @@ -542,6 +445,7 @@ def create( Maximum size in bytes that the memory pool can grow to ipc_enabled : bool, optional Whether to enable inter-process sharing capabilities. Default is False. + Note: On Windows, the pywin32 package is required for IPC support. win32_security_attributes : int, optional Custom Windows security attributes pointer. If 0 (default), a default security attributes structure will be created when needed on Windows platforms. @@ -555,6 +459,8 @@ def create( ------ ValueError If max_size is None + ImportError + If ipc_enabled is True on Windows but pywin32 is not installed CUDAError If pool creation fails """ From 62f1e93d9f9e8e019ea9096b92753310db71a443 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 26 Feb 2025 13:12:56 -0800 Subject: [PATCH 17/25] use library --- cuda_core/cuda/core/experimental/_memory.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index 62ab2b7ca..0756eb1fe 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -317,8 +317,10 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): # Set the DACL to the security descriptor sd.SetSecurityDescriptorDacl(1, dacl, 0) - # Get the pointer to the security descriptor - sd_pointer = sd.GetSecurityDescriptorSelf() + # Get the pointer to the security descriptor using the buffer interface + # PySECURITY_DESCRIPTOR objects support the buffer protocol + sd_view = memoryview(sd) + sd_pointer = ctypes.addressof(ctypes.c_char.from_buffer(sd_view)) # Create and initialize the security attributes structure sa = SECURITY_ATTRIBUTES() @@ -329,7 +331,7 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): # Store the security descriptor to prevent garbage collection if not hasattr(_create_win32_security_attributes, "_security_descriptors"): _create_win32_security_attributes._security_descriptors = [] - _create_win32_security_attributes._security_descriptors.append(sd) + _create_win32_security_attributes._security_descriptors.append((sd, sd_view)) # Keep both objects alive return ctypes.addressof(sa) From 2e5013a19e17c332be76f0ae3af1463ae7a2e52e Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 26 Feb 2025 13:15:41 -0800 Subject: [PATCH 18/25] use library --- cuda_core/cuda/core/experimental/_memory.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index 0756eb1fe..292f2be37 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -317,10 +317,15 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): # Set the DACL to the security descriptor sd.SetSecurityDescriptorDacl(1, dacl, 0) - # Get the pointer to the security descriptor using the buffer interface - # PySECURITY_DESCRIPTOR objects support the buffer protocol - sd_view = memoryview(sd) - sd_pointer = ctypes.addressof(ctypes.c_char.from_buffer(sd_view)) + # Get the raw bytes of the security descriptor using the buffer protocol + # This works because PySECURITY_DESCRIPTOR supports the buffer protocol + sd_bytes = bytes(sd) + + # Create a ctypes buffer from the bytes + sd_buffer = ctypes.create_string_buffer(sd_bytes) + + # Get the pointer to the buffer + sd_pointer = ctypes.cast(sd_buffer, ctypes.c_void_p).value # Create and initialize the security attributes structure sa = SECURITY_ATTRIBUTES() @@ -328,10 +333,10 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): sa.lpSecurityDescriptor = ctypes.c_void_p(sd_pointer) sa.bInheritHandle = False - # Store the security descriptor to prevent garbage collection + # Store the security descriptor and buffer to prevent garbage collection if not hasattr(_create_win32_security_attributes, "_security_descriptors"): _create_win32_security_attributes._security_descriptors = [] - _create_win32_security_attributes._security_descriptors.append((sd, sd_view)) # Keep both objects alive + _create_win32_security_attributes._security_descriptors.append((sd, sd_buffer)) # Keep both objects alive return ctypes.addressof(sa) From 0c82df8549e3260d5142c350f0bc39b8bc637681 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 26 Feb 2025 13:18:01 -0800 Subject: [PATCH 19/25] use library --- cuda_core/cuda/core/experimental/_memory.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index 292f2be37..4bed60d9c 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -333,10 +333,7 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): sa.lpSecurityDescriptor = ctypes.c_void_p(sd_pointer) sa.bInheritHandle = False - # Store the security descriptor and buffer to prevent garbage collection - if not hasattr(_create_win32_security_attributes, "_security_descriptors"): - _create_win32_security_attributes._security_descriptors = [] - _create_win32_security_attributes._security_descriptors.append((sd, sd_buffer)) # Keep both objects alive + print(f"sa: {sa}") return ctypes.addressof(sa) From 1c6ee59ab6d31cbd5aa9f17181b51cc56a6ec8d4 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 26 Feb 2025 13:36:52 -0800 Subject: [PATCH 20/25] use library --- cuda_core/cuda/core/experimental/_memory.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index 4bed60d9c..6d5a7b5b4 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -327,6 +327,9 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): # Get the pointer to the buffer sd_pointer = ctypes.cast(sd_buffer, ctypes.c_void_p).value + # print the contents of the buffer + print(f"sd_buffer: {sd_buffer}") + # Create and initialize the security attributes structure sa = SECURITY_ATTRIBUTES() sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) From 3074742363a5831542e5d09e18dfe58474f40875 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 26 Feb 2025 13:45:17 -0800 Subject: [PATCH 21/25] use library --- cuda_core/cuda/core/experimental/_memory.py | 55 +++++++++++---------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index 6d5a7b5b4..eb7d7de8e 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -15,9 +15,9 @@ import weakref from typing import Optional, Tuple, TypeVar -# Import win32security directly on Windows +# Remove pywin32 import if platform.system() == "Windows": - import win32security + from ctypes import windll, wintypes from cuda.core.experimental._dlpack import DLDeviceType, make_py_capsule from cuda.core.experimental._stream import default_stream @@ -300,43 +300,49 @@ def _create_win32_security_attributes(): if platform.system() != "Windows": return None + # Define constants needed for security descriptor creation + NULL = 0 + SECURITY_DESCRIPTOR_REVISION = 1 + # Define the Windows SECURITY_ATTRIBUTES structure class SECURITY_ATTRIBUTES(ctypes.Structure): _fields_ = [ - ("nLength", ctypes.c_ulong), - ("lpSecurityDescriptor", ctypes.c_void_p), - ("bInheritHandle", ctypes.c_int), + ("nLength", wintypes.DWORD), + ("lpSecurityDescriptor", wintypes.LPVOID), + ("bInheritHandle", wintypes.BOOL), ] - # Create a security descriptor using pywin32 - sd = win32security.SECURITY_DESCRIPTOR() - - # Create a blank DACL (this allows all access) - dacl = win32security.ACL() - - # Set the DACL to the security descriptor - sd.SetSecurityDescriptorDacl(1, dacl, 0) + # Get function pointers from Windows DLLs + InitializeSecurityDescriptor = windll.advapi32.InitializeSecurityDescriptor + InitializeSecurityDescriptor.argtypes = [wintypes.LPVOID, wintypes.DWORD] + InitializeSecurityDescriptor.restype = wintypes.BOOL - # Get the raw bytes of the security descriptor using the buffer protocol - # This works because PySECURITY_DESCRIPTOR supports the buffer protocol - sd_bytes = bytes(sd) + SetSecurityDescriptorDacl = windll.advapi32.SetSecurityDescriptorDacl + SetSecurityDescriptorDacl.argtypes = [wintypes.LPVOID, wintypes.BOOL, wintypes.LPVOID, wintypes.BOOL] + SetSecurityDescriptorDacl.restype = wintypes.BOOL - # Create a ctypes buffer from the bytes - sd_buffer = ctypes.create_string_buffer(sd_bytes) + # Create a security descriptor + sd_buffer = ctypes.create_string_buffer(64) # Size should be sufficient for a security descriptor + sd_pointer = ctypes.cast(sd_buffer, wintypes.LPVOID) - # Get the pointer to the buffer - sd_pointer = ctypes.cast(sd_buffer, ctypes.c_void_p).value + # Initialize the security descriptor + if not InitializeSecurityDescriptor(sd_pointer, SECURITY_DESCRIPTOR_REVISION): + raise ctypes.WinError() - # print the contents of the buffer - print(f"sd_buffer: {sd_buffer}") + # Set a NULL DACL (this allows all access) + if not SetSecurityDescriptorDacl(sd_pointer, True, NULL, False): + raise ctypes.WinError() # Create and initialize the security attributes structure sa = SECURITY_ATTRIBUTES() sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) - sa.lpSecurityDescriptor = ctypes.c_void_p(sd_pointer) + sa.lpSecurityDescriptor = sd_pointer sa.bInheritHandle = False - print(f"sa: {sa}") + # Store the security descriptor buffer to prevent garbage collection + if not hasattr(_create_win32_security_attributes, "_security_descriptors"): + _create_win32_security_attributes._security_descriptors = [] + _create_win32_security_attributes._security_descriptors.append(sd_buffer) return ctypes.addressof(sa) @@ -345,7 +351,6 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): def _cleanup_security_descriptors(): """Free any allocated security descriptors when the module is unloaded.""" if hasattr(_create_win32_security_attributes, "_security_descriptors"): - # The security descriptors are pywin32 objects that will be garbage collected _create_win32_security_attributes._security_descriptors.clear() From f35cf01196ad88eaa4add0ea7af67246d82e14e6 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 26 Feb 2025 13:54:19 -0800 Subject: [PATCH 22/25] use library --- cuda_core/cuda/core/experimental/_memory.py | 28 +++++++++++++-------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index eb7d7de8e..6983e213d 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -303,6 +303,7 @@ def _create_win32_security_attributes(): # Define constants needed for security descriptor creation NULL = 0 SECURITY_DESCRIPTOR_REVISION = 1 + SECURITY_DESCRIPTOR_MIN_LENGTH = 40 # Minimum size for a security descriptor # Define the Windows SECURITY_ATTRIBUTES structure class SECURITY_ATTRIBUTES(ctypes.Structure): @@ -321,17 +322,20 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): SetSecurityDescriptorDacl.argtypes = [wintypes.LPVOID, wintypes.BOOL, wintypes.LPVOID, wintypes.BOOL] SetSecurityDescriptorDacl.restype = wintypes.BOOL - # Create a security descriptor - sd_buffer = ctypes.create_string_buffer(64) # Size should be sufficient for a security descriptor + # Create a security descriptor with proper alignment + # Use ctypes.create_string_buffer to ensure proper memory alignment + sd_buffer = ctypes.create_string_buffer(SECURITY_DESCRIPTOR_MIN_LENGTH) sd_pointer = ctypes.cast(sd_buffer, wintypes.LPVOID) # Initialize the security descriptor if not InitializeSecurityDescriptor(sd_pointer, SECURITY_DESCRIPTOR_REVISION): - raise ctypes.WinError() + error = ctypes.WinError() + raise RuntimeError(f"Failed to initialize security descriptor: {error}") # Set a NULL DACL (this allows all access) if not SetSecurityDescriptorDacl(sd_pointer, True, NULL, False): - raise ctypes.WinError() + error = ctypes.WinError() + raise RuntimeError(f"Failed to set security descriptor DACL: {error}") # Create and initialize the security attributes structure sa = SECURITY_ATTRIBUTES() @@ -339,19 +343,23 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): sa.lpSecurityDescriptor = sd_pointer sa.bInheritHandle = False - # Store the security descriptor buffer to prevent garbage collection - if not hasattr(_create_win32_security_attributes, "_security_descriptors"): - _create_win32_security_attributes._security_descriptors = [] - _create_win32_security_attributes._security_descriptors.append(sd_buffer) + # Store both the security descriptor buffer and the security attributes structure + # to prevent garbage collection + if not hasattr(_create_win32_security_attributes, "_security_objects"): + _create_win32_security_attributes._security_objects = [] + # Keep both objects alive + _create_win32_security_attributes._security_objects.append((sd_buffer, sa)) + + # Return the pointer to the security attributes structure return ctypes.addressof(sa) # Add cleanup function for security descriptors def _cleanup_security_descriptors(): """Free any allocated security descriptors when the module is unloaded.""" - if hasattr(_create_win32_security_attributes, "_security_descriptors"): - _create_win32_security_attributes._security_descriptors.clear() + if hasattr(_create_win32_security_attributes, "_security_objects"): + _create_win32_security_attributes._security_objects.clear() atexit.register(_cleanup_security_descriptors) From bff995a25981174142d09d61d70e5aa5f7b1d19b Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 26 Feb 2025 13:55:28 -0800 Subject: [PATCH 23/25] use library --- cuda_core/cuda/core/experimental/_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index 6983e213d..b483db381 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -341,7 +341,7 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): sa = SECURITY_ATTRIBUTES() sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) sa.lpSecurityDescriptor = sd_pointer - sa.bInheritHandle = False + sa.bInheritHandle = True # Store both the security descriptor buffer and the security attributes structure # to prevent garbage collection From 5fd7f46bd64d3afae6d6f5a88aee7f7d8806dd38 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 5 Mar 2025 10:08:23 -0800 Subject: [PATCH 24/25] remove windows stuff from async mempool until it is supported --- cuda_core/cuda/core/experimental/_memory.py | 114 ++------------------ cuda_core/cuda/core/experimental/_utils.py | 10 +- 2 files changed, 10 insertions(+), 114 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_memory.py b/cuda_core/cuda/core/experimental/_memory.py index b483db381..a8d5e882b 100644 --- a/cuda_core/cuda/core/experimental/_memory.py +++ b/cuda_core/cuda/core/experimental/_memory.py @@ -7,18 +7,10 @@ import abc # Register cleanup function to be called at interpreter shutdown -import atexit - -# Add ctypes import for Windows security attributes -import ctypes import platform import weakref from typing import Optional, Tuple, TypeVar -# Remove pywin32 import -if platform.system() == "Windows": - from ctypes import windll, wintypes - from cuda.core.experimental._dlpack import DLDeviceType, make_py_capsule from cuda.core.experimental._stream import default_stream from cuda.core.experimental._utils import driver, handle_return @@ -283,88 +275,11 @@ def _get_platform_handle_type() -> int: if system == "Linux": return driver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR elif system == "Windows": - return driver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_WIN32 + raise RuntimeError("IPC support is not yet available on Windows") else: raise RuntimeError(f"Unsupported platform: {system}") -def _create_win32_security_attributes(): - """Creates a Windows SECURITY_ATTRIBUTES structure with default settings. - - The security descriptor is configured to allow access across processes, - which is appropriate for shared memory. - - Returns: - A pointer to a SECURITY_ATTRIBUTES structure or None if not on Windows. - """ - if platform.system() != "Windows": - return None - - # Define constants needed for security descriptor creation - NULL = 0 - SECURITY_DESCRIPTOR_REVISION = 1 - SECURITY_DESCRIPTOR_MIN_LENGTH = 40 # Minimum size for a security descriptor - - # Define the Windows SECURITY_ATTRIBUTES structure - class SECURITY_ATTRIBUTES(ctypes.Structure): - _fields_ = [ - ("nLength", wintypes.DWORD), - ("lpSecurityDescriptor", wintypes.LPVOID), - ("bInheritHandle", wintypes.BOOL), - ] - - # Get function pointers from Windows DLLs - InitializeSecurityDescriptor = windll.advapi32.InitializeSecurityDescriptor - InitializeSecurityDescriptor.argtypes = [wintypes.LPVOID, wintypes.DWORD] - InitializeSecurityDescriptor.restype = wintypes.BOOL - - SetSecurityDescriptorDacl = windll.advapi32.SetSecurityDescriptorDacl - SetSecurityDescriptorDacl.argtypes = [wintypes.LPVOID, wintypes.BOOL, wintypes.LPVOID, wintypes.BOOL] - SetSecurityDescriptorDacl.restype = wintypes.BOOL - - # Create a security descriptor with proper alignment - # Use ctypes.create_string_buffer to ensure proper memory alignment - sd_buffer = ctypes.create_string_buffer(SECURITY_DESCRIPTOR_MIN_LENGTH) - sd_pointer = ctypes.cast(sd_buffer, wintypes.LPVOID) - - # Initialize the security descriptor - if not InitializeSecurityDescriptor(sd_pointer, SECURITY_DESCRIPTOR_REVISION): - error = ctypes.WinError() - raise RuntimeError(f"Failed to initialize security descriptor: {error}") - - # Set a NULL DACL (this allows all access) - if not SetSecurityDescriptorDacl(sd_pointer, True, NULL, False): - error = ctypes.WinError() - raise RuntimeError(f"Failed to set security descriptor DACL: {error}") - - # Create and initialize the security attributes structure - sa = SECURITY_ATTRIBUTES() - sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) - sa.lpSecurityDescriptor = sd_pointer - sa.bInheritHandle = True - - # Store both the security descriptor buffer and the security attributes structure - # to prevent garbage collection - if not hasattr(_create_win32_security_attributes, "_security_objects"): - _create_win32_security_attributes._security_objects = [] - - # Keep both objects alive - _create_win32_security_attributes._security_objects.append((sd_buffer, sa)) - - # Return the pointer to the security attributes structure - return ctypes.addressof(sa) - - -# Add cleanup function for security descriptors -def _cleanup_security_descriptors(): - """Free any allocated security descriptors when the module is unloaded.""" - if hasattr(_create_win32_security_attributes, "_security_objects"): - _create_win32_security_attributes._security_objects.clear() - - -atexit.register(_cleanup_security_descriptors) - - class AsyncMempool(MemoryResource): """A CUDA memory pool for efficient memory allocation. @@ -452,9 +367,7 @@ def _from_device(dev_id: int) -> AsyncMempool: return AsyncMempool._init(dev_id, handle, ipc_enabled=False, need_close=False) @staticmethod - def create( - dev_id: int, max_size: int, ipc_enabled: bool = False, win32_security_attributes: int = 0 - ) -> AsyncMempool: + def create(dev_id: int, max_size: int, ipc_enabled: bool = False) -> AsyncMempool: """Create a new memory pool. Parameters @@ -465,10 +378,7 @@ def create( Maximum size in bytes that the memory pool can grow to ipc_enabled : bool, optional Whether to enable inter-process sharing capabilities. Default is False. - Note: On Windows, the pywin32 package is required for IPC support. - win32_security_attributes : int, optional - Custom Windows security attributes pointer. If 0 (default), a default security - attributes structure will be created when needed on Windows platforms. + Note: IPC support is not yet available on Windows. Returns ------- @@ -479,14 +389,17 @@ def create( ------ ValueError If max_size is None - ImportError - If ipc_enabled is True on Windows but pywin32 is not installed + RuntimeError + If ipc_enabled is True on Windows (not yet supported) CUDAError If pool creation fails """ if max_size is None: raise ValueError("max_size must be provided when creating a new memory pool") + if platform.system() == "Windows" and ipc_enabled: + raise RuntimeError("IPC support is not yet available on Windows") + properties = driver.CUmemPoolProps() properties.allocType = driver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED properties.handleTypes = ( @@ -496,16 +409,7 @@ def create( properties.location.id = dev_id properties.location.type = driver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE properties.maxSize = max_size - - # Set up Windows security attributes if needed - if platform.system() == "Windows" and ipc_enabled: - if win32_security_attributes == 0: - # Create default security attributes if none provided - win32_security_attributes = _create_win32_security_attributes() - properties.win32SecurityAttributes = win32_security_attributes - else: - properties.win32SecurityAttributes = 0 - + properties.win32SecurityAttributes = 0 properties.usage = 0 handle = handle_return(driver.cuMemPoolCreate(properties)) diff --git a/cuda_core/cuda/core/experimental/_utils.py b/cuda_core/cuda/core/experimental/_utils.py index 2a2889675..3538ae6c1 100644 --- a/cuda_core/cuda/core/experimental/_utils.py +++ b/cuda_core/cuda/core/experimental/_utils.py @@ -14,7 +14,6 @@ from cuda import cuda as driver from cuda import cudart as runtime from cuda import nvrtc -import traceback class CUDAError(Exception): @@ -36,14 +35,7 @@ def _check_error(error, handle=None): if err == driver.CUresult.CUDA_SUCCESS: err, desc = driver.cuGetErrorString(error) if err == driver.CUresult.CUDA_SUCCESS: - stack = traceback.extract_stack() - # Get the last 2 frames (excluding the current one) - relevant_stack = stack[-4:-1] - stack_info = "\n".join( - f" File '{frame.filename}', line {frame.lineno}, in {frame.name}\n {frame.line}" - for frame in relevant_stack - ) - raise CUDAError(f"{name.decode()}: {desc.decode()}\n{stack_info}") + raise CUDAError(f"{name.decode()}: {desc.decode()}") else: raise CUDAError(f"unknown error: {error}") elif isinstance(error, runtime.cudaError_t): From c39d9a9a9b2fb14eb132d2fdfc2027f1ad84d8e6 Mon Sep 17 00:00:00 2001 From: ksimpson Date: Wed, 5 Mar 2025 11:42:05 -0800 Subject: [PATCH 25/25] remove windows stuff from async mempool until it is supported --- cuda_core/tests/test_memory.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index b8fa03c1c..23045a30a 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -227,6 +227,10 @@ def test_mempool(): pytest.skip("Test requires CUDA 12 or higher") device = Device() device.set_current() + + if not device.properties.memory_pools_supported: + pytest.skip("Device does not support mempool operations") + pool_size = 2097152 # 2MB size # Test basic pool creation @@ -309,6 +313,10 @@ def test_mempool_properties(property_name, expected_type): device = Device() device.set_current() + + if not device.properties.memory_pools_supported: + pytest.skip("Device does not support mempool operations") + pool_size = 2097152 # 2MB size mr = AsyncMempool.create(device.device_id, pool_size, ipc_enabled=False)