Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions libdevice/sanitizer/msan_rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,15 @@ inline void ReportError(const uint32_t size, const char __SYCL_CONSTANT__ *file,

// This function is only used for shadow propagation
template <typename T>
void GroupAsyncCopy(uptr Dest, uptr Src, size_t NumElements, size_t Stride) {
void GroupAsyncCopy(uptr Dest, uptr Src, size_t NumElements, size_t Stride,
bool StrideOnSrc) {
auto DestPtr = (__SYCL_GLOBAL__ T *)Dest;
auto SrcPtr = (const __SYCL_GLOBAL__ T *)Src;
for (size_t i = 0; i < NumElements; i++) {
DestPtr[i] = SrcPtr[i * Stride];
if (StrideOnSrc)
DestPtr[i] = SrcPtr[i * Stride];
else
DestPtr[i * Stride] = SrcPtr[i];
}
}

Expand Down Expand Up @@ -748,16 +752,20 @@ __msan_unpoison_strided_copy(uptr dest, uint32_t dest_as, uptr src,

switch (element_size) {
case 1:
GroupAsyncCopy<int8_t>(shadow_dest, shadow_src, counts, stride);
GroupAsyncCopy<int8_t>(shadow_dest, shadow_src, counts, stride,
src_as == ADDRESS_SPACE_GLOBAL);
break;
case 2:
GroupAsyncCopy<int16_t>(shadow_dest, shadow_src, counts, stride);
GroupAsyncCopy<int16_t>(shadow_dest, shadow_src, counts, stride,
src_as == ADDRESS_SPACE_GLOBAL);
break;
case 4:
GroupAsyncCopy<int32_t>(shadow_dest, shadow_src, counts, stride);
GroupAsyncCopy<int32_t>(shadow_dest, shadow_src, counts, stride,
src_as == ADDRESS_SPACE_GLOBAL);
break;
case 8:
GroupAsyncCopy<int64_t>(shadow_dest, shadow_src, counts, stride);
GroupAsyncCopy<int64_t>(shadow_dest, shadow_src, counts, stride,
src_as == ADDRESS_SPACE_GLOBAL);
break;
default:
__spirv_ocl_printf(__msan_print_strided_copy_unsupport_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
assert((void *)Device != nullptr && "Device cannot be nullptr");

std::scoped_lock<ur_shared_mutex> Guard(Mutex);
auto CI = getAsanInterceptor()->getContextInfo(Context);
auto &Allocation = Allocations[Device];
ur_result_t URes = UR_RESULT_SUCCESS;
if (!Allocation) {
Expand All @@ -106,9 +107,9 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
}

if (HostPtr) {
ManagedQueue Queue(Context, Device);
ur_queue_handle_t InternalQueue = CI->getInternalQueue(Device);
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, Allocation, HostPtr, Size, 0, nullptr, nullptr);
InternalQueue, true, Allocation, HostPtr, Size, 0, nullptr, nullptr);
if (URes != UR_RESULT_SUCCESS) {
UR_LOG_L(
getContext()->logger, ERR,
Expand Down Expand Up @@ -147,10 +148,10 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {

// Copy data from last synced device to host
{
ManagedQueue Queue(Context, LastSyncedDevice.hDevice);
ur_queue_handle_t InternalQueue = CI->getInternalQueue(Device);
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, HostAllocation, LastSyncedDevice.MemHandle, Size, 0,
nullptr, nullptr);
InternalQueue, true, HostAllocation, LastSyncedDevice.MemHandle, Size,
0, nullptr, nullptr);
if (URes != UR_RESULT_SUCCESS) {
UR_LOG_L(getContext()->logger, ERR,
"Failed to migrate memory buffer data");
Expand All @@ -160,9 +161,10 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {

// Sync data back to device
{
ManagedQueue Queue(Context, Device);
ur_queue_handle_t InternalQueue = CI->getInternalQueue(Device);
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, Allocation, HostAllocation, Size, 0, nullptr, nullptr);
InternalQueue, true, Allocation, HostAllocation, Size, 0, nullptr,
nullptr);
if (URes != UR_RESULT_SUCCESS) {
UR_LOG_L(getContext()->logger, ERR,
"Failed to migrate memory buffer data");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreate(
std::shared_ptr<ContextInfo> CtxInfo =
getAsanInterceptor()->getContextInfo(hContext);
for (const auto &hDevice : CtxInfo->DeviceList) {
ManagedQueue InternalQueue(hContext, hDevice);
ur_queue_handle_t InternalQueue = CtxInfo->getInternalQueue(hDevice);
char *Handle = nullptr;
UR_CALL(pMemBuffer->getHandle(hDevice, Handle));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,15 @@ ur_result_t AsanInterceptor::preLaunchKernel(ur_kernel_handle_t Kernel,
auto ContextInfo = getContextInfo(Context);
auto DeviceInfo = getDeviceInfo(Device);

ManagedQueue InternalQueue(Context, Device);
if (!InternalQueue) {
UR_LOG_L(getContext()->logger, ERR, "Failed to create internal queue");
return UR_RESULT_ERROR_INVALID_QUEUE;
}
ur_queue_handle_t InternalQueue = ContextInfo->getInternalQueue(Device);

UR_CALL(prepareLaunch(ContextInfo, DeviceInfo, InternalQueue, Kernel,
LaunchInfo));

UR_CALL(updateShadowMemory(ContextInfo, DeviceInfo, InternalQueue));

UR_CALL(getContext()->urDdiTable.Queue.pfnFinish(InternalQueue));

return UR_RESULT_SUCCESS;
}

Expand Down Expand Up @@ -467,6 +465,7 @@ ur_result_t AsanInterceptor::unregisterProgram(ur_program_handle_t Program) {

ur_result_t AsanInterceptor::registerSpirKernels(ur_program_handle_t Program) {
auto Context = GetContext(Program);
auto CI = getContextInfo(Context);
std::vector<ur_device_handle_t> Devices = GetDevices(Program);

for (auto Device : Devices) {
Expand All @@ -484,11 +483,11 @@ ur_result_t AsanInterceptor::registerSpirKernels(ur_program_handle_t Program) {
assert((MetadataSize % sizeof(SpirKernelInfo) == 0) &&
"SpirKernelMetadata size is not correct");

ManagedQueue Queue(Context, Device);
ur_queue_handle_t InternalQueue = CI->getInternalQueue(Device);

std::vector<SpirKernelInfo> SKInfo(NumOfSpirKernel);
Result = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, &SKInfo[0], MetadataPtr,
InternalQueue, true, &SKInfo[0], MetadataPtr,
sizeof(SpirKernelInfo) * NumOfSpirKernel, 0, nullptr, nullptr);
if (Result != UR_RESULT_SUCCESS) {
UR_LOG_L(getContext()->logger, ERR, "Can't read the value of <{}>: {}",
Expand All @@ -504,7 +503,7 @@ ur_result_t AsanInterceptor::registerSpirKernels(ur_program_handle_t Program) {
}
std::vector<char> KernelNameV(SKI.Size);
Result = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, KernelNameV.data(), (void *)SKI.KernelName,
InternalQueue, true, KernelNameV.data(), (void *)SKI.KernelName,
sizeof(char) * SKI.Size, 0, nullptr, nullptr);
if (Result != UR_RESULT_SUCCESS) {
UR_LOG_L(getContext()->logger, ERR, "Can't read kernel name: {}",
Expand Down Expand Up @@ -537,7 +536,7 @@ AsanInterceptor::registerDeviceGlobals(ur_program_handle_t Program) {
assert(ProgramInfo != nullptr && "unregistered program!");

for (auto Device : Devices) {
ManagedQueue Queue(Context, Device);
ur_queue_handle_t InternalQueue = ContextInfo->getInternalQueue(Device);

size_t MetadataSize;
void *MetadataPtr;
Expand All @@ -554,7 +553,7 @@ AsanInterceptor::registerDeviceGlobals(ur_program_handle_t Program) {
"DeviceGlobal metadata size is not correct");
std::vector<DeviceGlobalInfo> GVInfos(NumOfDeviceGlobal);
Result = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, &GVInfos[0], MetadataPtr,
InternalQueue, true, &GVInfos[0], MetadataPtr,
sizeof(DeviceGlobalInfo) * NumOfDeviceGlobal, 0, nullptr, nullptr);
if (Result != UR_RESULT_SUCCESS) {
UR_LOG_L(getContext()->logger, ERR, "Device Global[{}] Read Failed: {}",
Expand Down Expand Up @@ -932,6 +931,8 @@ bool ProgramInfo::isKernelInstrumented(ur_kernel_handle_t Kernel) const {
ContextInfo::~ContextInfo() {
Stats.Print(Handle);

InternalQueueMap.clear();

[[maybe_unused]] ur_result_t URes;
if (USMPool) {
URes = getContext()->urDdiTable.USM.pfnPoolRelease(USMPool);
Expand Down Expand Up @@ -971,6 +972,13 @@ ur_usm_pool_handle_t ContextInfo::getUSMPool() {
return USMPool;
}

ur_queue_handle_t ContextInfo::getInternalQueue(ur_device_handle_t Device) {
std::scoped_lock<ur_shared_mutex> Guard(InternalQueueMapMutex);
if (!InternalQueueMap[Device])
InternalQueueMap[Device].emplace(Handle, Device);
return *InternalQueueMap[Device];
}

AsanRuntimeDataWrapper::~AsanRuntimeDataWrapper() {
[[maybe_unused]] ur_result_t Result;
if (Host.LocalArgs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "asan_statistics.hpp"
#include "sanitizer_common/sanitizer_common.hpp"
#include "sanitizer_common/sanitizer_options.hpp"
#include "sanitizer_common/sanitizer_utils.hpp"
#include "ur_sanitizer_layer.hpp"

#include <memory>
Expand Down Expand Up @@ -143,6 +144,10 @@ struct ContextInfo {
std::vector<ur_device_handle_t> DeviceList;
std::unordered_map<ur_device_handle_t, AllocInfoList> AllocInfosMap;

ur_shared_mutex InternalQueueMapMutex;
std::unordered_map<ur_device_handle_t, std::optional<ManagedQueue>>
InternalQueueMap;

AsanStatsWrapper Stats;

explicit ContextInfo(ur_context_handle_t Context) : Handle(Context) {
Expand All @@ -163,6 +168,8 @@ struct ContextInfo {
}

ur_usm_pool_handle_t getUSMPool();

ur_queue_handle_t getInternalQueue(ur_device_handle_t);
};

struct AsanRuntimeDataWrapper {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ ur_usm_type_t GetUSMType(ur_context_handle_t Context, const void *MemPtr) {
} // namespace

ManagedQueue::ManagedQueue(ur_context_handle_t Context,
ur_device_handle_t Device) {
ur_device_handle_t Device, bool IsOutOfOrder) {
ur_queue_properties_t Prop{UR_STRUCTURE_TYPE_QUEUE_PROPERTIES, nullptr,
UR_QUEUE_FLAG_OUT_OF_ORDER_EXEC_MODE_ENABLE};
[[maybe_unused]] auto Result = getContext()->urDdiTable.Queue.pfnCreate(
Context, Device, nullptr, &Handle);
Context, Device, IsOutOfOrder ? &Prop : nullptr, &Handle);
assert(Result == UR_RESULT_SUCCESS && "Failed to create ManagedQueue");
UR_LOG_L(getContext()->logger, DEBUG, ">>> ManagedQueue {}", (void *)Handle);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
namespace ur_sanitizer_layer {

struct ManagedQueue {
ManagedQueue(ur_context_handle_t Context, ur_device_handle_t Device);
ManagedQueue(ur_context_handle_t Context, ur_device_handle_t Device,
bool IsOutOfOrder = false);
~ManagedQueue();

// Disable copy semantics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {

std::scoped_lock<ur_shared_mutex> Guard(Mutex);
auto &Allocation = Allocations[Device];
auto CI = getTsanInterceptor()->getContextInfo(Context);
ur_result_t URes = UR_RESULT_SUCCESS;
if (!Allocation) {
ur_usm_desc_t USMDesc{};
Expand All @@ -114,7 +115,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
}

if (HostPtr) {
ManagedQueue Queue(Context, Device);
ur_queue_handle_t Queue = CI->getInternalQueue(Device);
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, Allocation, HostPtr, Size, 0, nullptr, nullptr);
if (URes != UR_RESULT_SUCCESS) {
Expand Down Expand Up @@ -155,7 +156,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {

// Copy data from last synced device to host
{
ManagedQueue Queue(Context, LastSyncedDevice.hDevice);
ur_queue_handle_t Queue = CI->getInternalQueue(Device);
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, HostAllocation, LastSyncedDevice.MemHandle, Size, 0,
nullptr, nullptr);
Expand All @@ -168,7 +169,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {

// Sync data back to device
{
ManagedQueue Queue(Context, Device);
ur_queue_handle_t Queue = CI->getInternalQueue(Device);
URes = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Queue, true, Allocation, HostAllocation, Size, 0, nullptr, nullptr);
if (URes != UR_RESULT_SUCCESS) {
Expand All @@ -185,8 +186,10 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
}

ur_result_t MemBuffer::free() {
for (const auto &[_, Ptr] : Allocations) {
ur_result_t URes = getTsanInterceptor()->releaseMemory(Context, Ptr);
for (const auto &[Device, Ptr] : Allocations) {
ur_result_t URes = Device
? getTsanInterceptor()->releaseMemory(Context, Ptr)
: getContext()->urDdiTable.USM.pfnFree(Context, Ptr);
if (URes != UR_RESULT_SUCCESS) {
UR_LOG_L(getContext()->logger, ERR, "Failed to free buffer handle {}",
(void *)Ptr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ ur_result_t urMemBufferCreate(
std::shared_ptr<ContextInfo> CtxInfo =
getTsanInterceptor()->getContextInfo(hContext);
for (const auto &hDevice : CtxInfo->DeviceList) {
ManagedQueue InternalQueue(hContext, hDevice);
ur_queue_handle_t InternalQueue = CtxInfo->getInternalQueue(hDevice);
char *Handle = nullptr;
UR_CALL(pMemBuffer->getHandle(hDevice, Handle));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
Expand Down
Loading