Skip to content

Commit

Permalink
[L0 v2] introduce raii wrapper for UR handles
Browse files Browse the repository at this point in the history
Some entities (e.g. devices) do not need to be retained
as they are owned by the platform. For such cases, only
validate RefCount instead of actually increasing/decreasing it.
  • Loading branch information
igchor committed Jan 21, 2025
1 parent f058cb2 commit 7a2affd
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 69 deletions.
75 changes: 75 additions & 0 deletions source/adapters/level_zero/v2/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <ze_api.h>

#include "../common.hpp"
#include "../ur_interface_loader.hpp"
#include "logger/ur_logger.hpp"
namespace {
#define DECLARE_DESTROY_FUNCTION(name) \
Expand Down Expand Up @@ -118,5 +119,79 @@ using ze_context_handle_t = HANDLE_WRAPPER_TYPE(::ze_context_handle_t,
using ze_command_list_handle_t = HANDLE_WRAPPER_TYPE(::ze_command_list_handle_t,
zeCommandListDestroy);

// Move-only owner of a UR handle.
//
// `retain` is called on the handle when a ref_counted is constructed from it.
// `release` is called when the ref_counted is destroyed, or when a handle it
// already holds is replaced by move-assignment. A moved-from ref_counted
// holds a null handle and releases nothing.
template <typename URHandle, ur_result_t (*retain)(URHandle),
          ur_result_t (*release)(URHandle)>
struct ref_counted {
  // Stores `handle` and passes it to `retain` unconditionally. The
  // constructor is not marked explicit, so a raw handle converts implicitly.
  ref_counted(URHandle handle) : handle(handle) { retain(handle); }

  // Calls `release` on the held handle. The move operations below leave
  // their source with a null handle, and such an object must not forward
  // that null to `release`.
  ~ref_counted() {
    if (handle) {
      release(handle);
    }
  }

  // Non-owning access to the underlying handle.
  operator URHandle() const { return handle; }
  URHandle operator->() const { return handle; }

  // Copying is disabled; ownership can only be moved.
  ref_counted(const ref_counted &) = delete;
  ref_counted &operator=(const ref_counted &) = delete;

  // Takes over `other`'s handle without calling `retain`; `other` is left
  // holding nullptr.
  ref_counted(ref_counted &&other) noexcept : handle(other.handle) {
    other.handle = nullptr;
  }

  // Releases the currently held handle (if any), then takes over `other`'s
  // handle without calling `retain`. Self-assignment is a no-op.
  ref_counted &operator=(ref_counted &&other) noexcept {
    if (this == &other) {
      return *this;
    }

    if (handle) {
      release(handle);
    }

    handle = other.handle;
    other.handle = nullptr;
    return *this;
  }

  // Returns the held handle; nullptr after this object has been moved from.
  URHandle get() const { return handle; }

private:
  URHandle handle;
};

// Primary template is only declared; each supported handle type gets an
// explicit specialization through DECLARE_REF_COUNTER_TRAITS below.
template <typename URHandle> struct ref_counted_traits;

// Defines ref_counted_traits<URHandle> with static retain()/release()
// members. Each member asserts that the handle is non-null and then forwards
// it to retainFn / releaseFn, returning that function's ur_result_t.
#define DECLARE_REF_COUNTER_TRAITS(URHandle, retainFn, releaseFn) \
template <> struct ref_counted_traits<URHandle> { \
static ur_result_t retain(URHandle handle) { \
assert(handle); \
return retainFn(handle); \
} \
static ur_result_t release(URHandle handle) { \
assert(handle); \
return releaseFn(handle); \
} \
};

// This version of ref_counted calls retain/release functions.
// The functions are looked up in ref_counted_traits<URHandle>, so `rc` can
// only be instantiated for handle types that have a traits specialization.
template <typename URHandle>
using rc = ref_counted<URHandle, ref_counted_traits<URHandle>::retain,
ref_counted_traits<URHandle>::release>;

// Traits specializations for the handle types that can be wrapped in `rc`:
// context, mem, program, queue and kernel. Each one forwards to the matching
// ur::level_zero::ur*Retain / ur*Release entry point.
DECLARE_REF_COUNTER_TRAITS(::ur_context_handle_t,
ur::level_zero::urContextRetain,
ur::level_zero::urContextRelease);
DECLARE_REF_COUNTER_TRAITS(::ur_mem_handle_t, ur::level_zero::urMemRetain,
ur::level_zero::urMemRelease);
DECLARE_REF_COUNTER_TRAITS(::ur_program_handle_t,
ur::level_zero::urProgramRetain,
ur::level_zero::urProgramRelease);
DECLARE_REF_COUNTER_TRAITS(::ur_queue_handle_t, ur::level_zero::urQueueRetain,
ur::level_zero::urQueueRelease);
DECLARE_REF_COUNTER_TRAITS(::ur_kernel_handle_t, ur::level_zero::urKernelRetain,
ur::level_zero::urKernelRelease);

} // namespace raii

} // namespace v2
3 changes: 3 additions & 0 deletions source/adapters/level_zero/v2/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ ur_result_t ur_event_handle_t_::release() {
if (isTimestamped() && !getEventEndTimestamp()) {
// L0 will write end timestamp to this event some time in the future,
// so we can't release it yet.

// If this code is being executed, queue has to be valid (queue cannot
// be released before all operations complete).
assert(hQueue);
hQueue->deferEventFree(this);
return UR_RESULT_SUCCESS;
Expand Down
26 changes: 25 additions & 1 deletion source/adapters/level_zero/v2/event.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ struct ur_event_handle_t_ : _ur_object {
const ze_event_handle_t hZeEvent;

// queue and commandType that this event is associated with, set by enqueue
// commands
// commands.
ur_queue_handle_t hQueue = nullptr;
ur_command_t commandType = UR_COMMAND_FORCE_UINT32;

Expand Down Expand Up @@ -134,3 +134,27 @@ struct ur_native_event_t : ur_event_handle_t_ {
private:
v2::raii::ze_event_handle_t zeEvent;
};

namespace v2 {

// Retain function used by deferred_event_handle_t. It does not change the
// reference count: it only asserts that the event is non-null and that its
// RefCount is zero, then returns UR_RESULT_SUCCESS.
inline ur_result_t
deferredEventRetain([[maybe_unused]] ur_event_handle_t hEvent) {
  assert(hEvent);
  // ur_event_handle_t_ derives publicly from _ur_object, so RefCount is
  // reachable through the handle itself. This is a proper derived-to-base
  // access and does not depend on object layout the way reinterpret_cast does.
  assert(hEvent->RefCount.load() == 0);
  return UR_RESULT_SUCCESS;
}

// Release function used by deferred_event_handle_t. It does not change the
// reference count: it only asserts that the event is non-null and that its
// RefCount is zero, then returns UR_RESULT_SUCCESS.
// NOTE: the identifier is spelled "deffered"; it is kept as-is because other
// code refers to it by this name.
inline ur_result_t
defferedEventRelease([[maybe_unused]] ur_event_handle_t hEvent) {
  assert(hEvent);
  // ur_event_handle_t_ derives publicly from _ur_object, so RefCount is
  // reachable through the handle itself. This is a proper derived-to-base
  // access and does not depend on object layout the way reinterpret_cast does.
  assert(hEvent->RefCount.load() == 0);
  return UR_RESULT_SUCCESS;
}

// Deferred events have a RefCount equal to 0; the only operation that can be
// called on them is releaseDeferred(). The retain/release functions plugged
// in here therefore only assert that the count is zero and never modify it.
// ("defferedEventRelease" is the spelling of the declared function.)
using deferred_event_handle_t =
v2::raii::ref_counted<ur_event_handle_t, deferredEventRetain,
defferedEventRelease>;

} // namespace v2
1 change: 0 additions & 1 deletion source/adapters/level_zero/v2/event_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ namespace v2 {

class event_pool {
public:
// store weak reference to the queue as event_pool is part of the queue
event_pool(ur_context_handle_t hContext,
std::unique_ptr<event_provider> Provider)
: hContext(hContext), provider(std::move(Provider)),
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/level_zero/v2/event_pool_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace v2 {
event_pool_cache::event_pool_cache(ur_context_handle_t hContext,
size_t max_devices,
ProviderCreateFunc ProviderCreate)
: hContext(hContext), providerCreate(ProviderCreate) {
: hContext(std::move(hContext)), providerCreate(ProviderCreate) {
pools.resize(max_devices * (1ULL << EVENT_FLAGS_USED_BITS));
}

Expand Down
40 changes: 11 additions & 29 deletions source/adapters/level_zero/v2/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,10 @@ ur_single_device_kernel_t::ur_single_device_kernel_t(ur_device_handle_t hDevice,
};
}

// Calls reset() on the hKernel wrapper and always returns UR_RESULT_SUCCESS.
ur_result_t ur_single_device_kernel_t::release() {
  hKernel.reset();

  return UR_RESULT_SUCCESS;
}

ur_kernel_handle_t_::ur_kernel_handle_t_(ur_program_handle_t hProgram,
const char *kernelName)
: hProgram(hProgram),
deviceKernels(hProgram->Context->getPlatform()->getNumDevices()) {
ur::level_zero::urProgramRetain(hProgram);

for (auto &Dev : hProgram->AssociatedDevices) {
auto zeDevice = Dev->ZeDevice;
// Program may be associated with all devices from the context but built
Expand Down Expand Up @@ -75,8 +68,6 @@ ur_kernel_handle_t_::ur_kernel_handle_t_(
const ur_kernel_native_properties_t *pProperties)
: hProgram(hProgram),
deviceKernels(context ? context->getPlatform()->getNumDevices() : 0) {
ur::level_zero::urProgramRetain(hProgram);

auto ownZeHandle = pProperties ? pProperties->isNativeHandleOwned : false;

ze_kernel_handle_t zeKernel = ur_cast<ze_kernel_handle_t>(hNativeKernel);
Expand All @@ -94,24 +85,6 @@ ur_kernel_handle_t_::ur_kernel_handle_t_(
completeInitialization();
}

// Decrements RefCount. Only when decrementAndTest() reports true does it
// tear the kernel down: every populated per-device entry has its hKernel
// reset, the program is released, and the object deletes itself.
// Returns UR_RESULT_SUCCESS on both paths; a failing program release is
// surfaced through UR_CALL_THROWS.
ur_result_t ur_kernel_handle_t_::release() {
  if (RefCount.decrementAndTest()) {
    // Reset the kernels by hand so that errors can be propagated.
    for (auto &deviceKernel : deviceKernels) {
      if (deviceKernel) {
        deviceKernel->hKernel.reset();
      }
    }

    UR_CALL_THROWS(ur::level_zero::urProgramRelease(hProgram));

    delete this;
  }

  return UR_RESULT_SUCCESS;
}

void ur_kernel_handle_t_::completeInitialization() {
// Cache kernel name. Should be the same for all devices
assert(deviceKernels.size() > 0);
Expand Down Expand Up @@ -318,6 +291,15 @@ std::vector<char> ur_kernel_handle_t_::getSourceAttributes() const {
return attributes;
}

// Decrements RefCount and deletes the object when decrementAndTest() reports
// true. Returns UR_RESULT_SUCCESS on both paths.
ur_result_t ur_kernel_handle_t_::release() {
  if (RefCount.decrementAndTest()) {
    delete this;
  }

  return UR_RESULT_SUCCESS;
}

namespace ur::level_zero {
ur_result_t urKernelCreate(ur_program_handle_t hProgram,
const char *pKernelName,
Expand Down Expand Up @@ -365,8 +347,8 @@ ur_result_t urKernelRetain(
}

ur_result_t urKernelRelease(
/// [in] handle for the Kernel to release
ur_kernel_handle_t hKernel) try {
ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to release
) try {
return hKernel->release();
} catch (...) {
return exceptionToResult(std::current_exception());
Expand Down
5 changes: 2 additions & 3 deletions source/adapters/level_zero/v2/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
struct ur_single_device_kernel_t {
ur_single_device_kernel_t(ur_device_handle_t hDevice,
ze_kernel_handle_t hKernel, bool ownZeHandle);
ur_result_t release();

ur_device_handle_t hDevice;
v2::raii::ze_kernel_handle_t hKernel;
Expand Down Expand Up @@ -74,7 +73,7 @@ struct ur_kernel_handle_t_ : _ur_object {

std::vector<char> getSourceAttributes() const;

// Perform cleanup.
// Decrease the refcount and call destructor if new refcount == 0.
ur_result_t release();

// Add a pending memory allocation for which device is not yet known.
Expand All @@ -92,7 +91,7 @@ struct ur_kernel_handle_t_ : _ur_object {

private:
// Keep the program of the kernel.
const ur_program_handle_t hProgram;
const v2::raii::rc<ur_program_handle_t> hProgram;

// Vector of ur_single_device_kernel_t indexed by deviceIndex().
std::vector<std::optional<ur_single_device_kernel_t>> deviceKernels;
Expand Down
17 changes: 3 additions & 14 deletions source/adapters/level_zero/v2/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ void *ur_discrete_mem_handle_t::allocateOnDevice(ur_device_handle_t hDevice,
hContext, hDevice, nullptr, UR_USM_TYPE_DEVICE, size, &ptr));

deviceAllocations[id] =
usm_unique_ptr_t(ptr, [hContext = this->hContext](void *ptr) {
usm_unique_ptr_t(ptr, [hContext = this->hContext.get()](void *ptr) {
auto ret = hContext->getDefaultUSMPool()->free(ptr);
if (ret != UR_RESULT_SUCCESS) {
logger::error("Failed to free device memory: {}", ret);
Expand Down Expand Up @@ -230,7 +230,7 @@ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(
devicePtr = allocateOnDevice(hDevice, size);
} else {
deviceAllocations[hDevice->Id.value()] = usm_unique_ptr_t(
devicePtr, [hContext = this->hContext, ownZePtr](void *ptr) {
devicePtr, [hContext = this->hContext.get(), ownZePtr](void *ptr) {
if (!ownZePtr) {
return;
}
Expand Down Expand Up @@ -361,22 +361,11 @@ static bool useHostBuffer(ur_context_handle_t hContext) {
ZE_DEVICE_PROPERTY_FLAG_INTEGRATED;
}

namespace ur::level_zero {
ur_result_t urMemRetain(ur_mem_handle_t hMem);
ur_result_t urMemRelease(ur_mem_handle_t hMem);
} // namespace ur::level_zero

ur_mem_sub_buffer_t::ur_mem_sub_buffer_t(ur_mem_handle_t hParent, size_t offset,
size_t size,
device_access_mode_t accessMode)
: ur_mem_handle_t_(hParent->getContext(), size, accessMode),
hParent(hParent), offset(offset), size(size) {
ur::level_zero::urMemRetain(hParent);
}

ur_mem_sub_buffer_t::~ur_mem_sub_buffer_t() {
ur::level_zero::urMemRelease(hParent);
}
hParent(hParent), offset(offset), size(size) {}

void *ur_mem_sub_buffer_t::getDevicePtr(
ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
Expand Down
6 changes: 2 additions & 4 deletions source/adapters/level_zero/v2/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct ur_mem_handle_t_ : private _ur_object {

protected:
const device_access_mode_t accessMode;
const ur_context_handle_t hContext;
const v2::raii::rc<ur_context_handle_t> hContext;
const size_t size;
};

Expand Down Expand Up @@ -140,7 +140,6 @@ struct ur_discrete_mem_handle_t : public ur_mem_handle_t_ {
std::vector<usm_unique_ptr_t> deviceAllocations;

// Specifies device on which the latest allocation resides.
// If null, there is no allocation.
ur_device_handle_t activeAllocationDevice = nullptr;

// If not null, copy the buffer content back to this memory on release.
Expand All @@ -157,7 +156,6 @@ struct ur_discrete_mem_handle_t : public ur_mem_handle_t_ {
struct ur_mem_sub_buffer_t : public ur_mem_handle_t_ {
ur_mem_sub_buffer_t(ur_mem_handle_t hParent, size_t offset, size_t size,
device_access_mode_t accesMode);
~ur_mem_sub_buffer_t();

void *
getDevicePtr(ur_device_handle_t, device_access_mode_t, size_t offset,
Expand All @@ -172,7 +170,7 @@ struct ur_mem_sub_buffer_t : public ur_mem_handle_t_ {
ur_shared_mutex &getMutex() override;

private:
ur_mem_handle_t hParent;
v2::raii::rc<ur_mem_handle_t> hParent;
size_t offset;
size_t size;
};
17 changes: 4 additions & 13 deletions source/adapters/level_zero/v2/queue_immediate_in_order.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ ur_queue_immediate_in_order_t::getSignalEvent(ur_event_handle_t *hUserEvent,
ur_command_t commandType) {
if (hUserEvent) {
*hUserEvent = eventPool->allocate();
(*hUserEvent)->resetQueueAndCommand(this, commandType);
(*hUserEvent)->resetQueueAndCommand(ur_queue_handle_t(this), commandType);
return *hUserEvent;
} else {
return nullptr;
Expand All @@ -121,7 +121,7 @@ ur_queue_immediate_in_order_t::queueGetInfo(ur_queue_info_t propName,
// TODO: consider support for queue properties and size
switch ((uint32_t)propName) { // cast to avoid warnings on EXT enum values
case UR_QUEUE_INFO_CONTEXT:
return ReturnValue(hContext);
return ReturnValue(hContext.get());
case UR_QUEUE_INFO_DEVICE:
return ReturnValue(hDevice);
case UR_QUEUE_INFO_REFERENCE_COUNT:
Expand Down Expand Up @@ -162,7 +162,7 @@ ur_result_t ur_queue_immediate_in_order_t::queueRelease() {

void ur_queue_immediate_in_order_t::deferEventFree(ur_event_handle_t hEvent) {
std::unique_lock<ur_shared_mutex> lock(this->Mutex);
deferredEvents.push_back(hEvent);
deferredEvents.emplace_back(hEvent);
}

ur_result_t ur_queue_immediate_in_order_t::queueGetNativeHandle(
Expand All @@ -184,24 +184,15 @@ ur_result_t ur_queue_immediate_in_order_t::queueFinish() {
ZE2UR_CALL(zeCommandListHostSynchronize,
(handler.commandList.get(), UINT64_MAX));

// Free deferred events
for (auto &hEvent : deferredEvents) {
UR_CALL(hEvent->releaseDeferred());
}
deferredEvents.clear();

// Free deferred kernels
for (auto &hKernel : submittedKernels) {
UR_CALL(hKernel->release());
}
submittedKernels.clear();

return UR_RESULT_SUCCESS;
}

void ur_queue_immediate_in_order_t::recordSubmittedKernel(
ur_kernel_handle_t hKernel) {
submittedKernels.push_back(hKernel);
submittedKernels.emplace_back(hKernel);
hKernel->RefCount.increment();
}

Expand Down
Loading

0 comments on commit 7a2affd

Please sign in to comment.