Skip to content

[NFC][SYCL] Prefer to pass context_impl by raw ptr/ref #18936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: sycl
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions sycl/include/sycl/ext/oneapi/memcpy2d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void handler::ext_oneapi_memcpy2d(void *Dest, size_t DestPitch, const void *Src,
#endif

// Get the type of the pointers.
context Ctx = detail::createSyclObjFromImpl<context>(getContextImplPtr());
detail::context_impl &Ctx = getContextImpl();
usm::alloc SrcAllocType = get_pointer_type(Src, Ctx);
usm::alloc DestAllocType = get_pointer_type(Dest, Ctx);
bool SrcIsHost =
Expand Down Expand Up @@ -71,7 +71,7 @@ void handler::ext_oneapi_copy2d(const T *Src, size_t SrcPitch, T *Dest,
"to the width specified in 'ext_oneapi_copy2d'");

// Get the type of the pointers.
context Ctx = detail::createSyclObjFromImpl<context>(getContextImplPtr());
detail::context_impl &Ctx = getContextImpl();
usm::alloc SrcAllocType = get_pointer_type(Src, Ctx);
usm::alloc DestAllocType = get_pointer_type(Dest, Ctx);
bool SrcIsHost =
Expand Down Expand Up @@ -106,7 +106,7 @@ void handler::ext_oneapi_memset2d(void *Dest, size_t DestPitch, int Value,
"to the width specified in 'ext_oneapi_memset2d'");
T CharVal = static_cast<T>(Value);

context Ctx = detail::createSyclObjFromImpl<context>(getContextImplPtr());
detail::context_impl &Ctx = getContextImpl();
usm::alloc DestAllocType = get_pointer_type(Dest, Ctx);

// If the backends supports 2D fill we use that. Otherwise we use a fallback
Expand All @@ -130,7 +130,7 @@ void handler::ext_oneapi_fill2d(void *Dest, size_t DestPitch, const T &Pattern,
"Destination pitch must be greater than or equal "
"to the width specified in 'ext_oneapi_fill2d'");

context Ctx = detail::createSyclObjFromImpl<context>(getContextImplPtr());
detail::context_impl &Ctx = getContextImpl();
usm::alloc DestAllocType = get_pointer_type(Dest, Ctx);

// If the backends supports 2D fill we use that. Otherwise we use a fallback
Expand Down
3 changes: 3 additions & 0 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3530,7 +3530,10 @@ class __SYCL_EXPORT handler {
UserRange, KernelFunc};
}

#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
const std::shared_ptr<detail::context_impl> &getContextImplPtr() const;
#endif
detail::context_impl &getContextImpl() const;

// Checks if 2D memory operations are supported by the underlying platform.
bool supportsUSMMemcpy2D();
Expand Down
6 changes: 6 additions & 0 deletions sycl/include/sycl/interop_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,12 @@ class interop_handle {
friend class detail::DispatchHostTask;
using ReqToMem = std::pair<detail::AccessorImplHost *, ur_mem_handle_t>;

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
// Clean this up (no shared pointers). Not doing it right now because I expect
// there will be several iterations of simplifications possible and it would
// be hard to track which of them made their way into a minor public release
// and which didn't. Let's just clean it up once during ABI breaking window.
#endif
interop_handle(std::vector<ReqToMem> MemObjs,
const std::shared_ptr<detail::queue_impl> &Queue,
const std::shared_ptr<detail::device_impl> &Device,
Expand Down
11 changes: 11 additions & 0 deletions sycl/include/sycl/usm/usm_pointer_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,23 @@ inline namespace _V1 {
class device;
class context;

namespace detail {
class context_impl;
__SYCL_EXPORT usm::alloc get_pointer_type(const void *ptr, context_impl &ctxt);
} // namespace detail

// Pointer queries
/// Query the allocation type from a USM pointer
///
/// \param ptr is the USM pointer to query
/// \param ctxt is the sycl context the ptr was allocated in
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
inline usm::alloc get_pointer_type(const void *ptr, const context &ctxt) {
return get_pointer_type(ptr, *getSyclObjImpl(ctxt));
}
#else
__SYCL_EXPORT usm::alloc get_pointer_type(const void *ptr, const context &ctxt);
#endif

/// Queries the device against which the pointer was allocated
/// Throws an exception with errc::invalid error code if ptr is a host
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/async_alloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void *async_malloc(sycl::handler &h, sycl::usm::alloc kind, size_t size) {
sycl::make_error_code(sycl::errc::feature_not_supported),
"Only device backed asynchronous allocations are supported!");

auto &Adapter = h.getContextImplPtr()->getAdapter();
auto &Adapter = h.getContextImpl().getAdapter();

// Get CG event dependencies for this allocation.
const auto &DepEvents = h.impl->CGData.MEvents;
Expand Down Expand Up @@ -117,7 +117,7 @@ __SYCL_EXPORT void *async_malloc(const sycl::queue &q, sycl::usm::alloc kind,
__SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
const memory_pool &pool) {

auto &Adapter = h.getContextImplPtr()->getAdapter();
auto &Adapter = h.getContextImpl().getAdapter();
auto &memPoolImpl = sycl::detail::getSyclObjImpl(pool);

// Get CG event dependencies for this allocation.
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/backend_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ inline namespace _V1 {
namespace detail {

template <class T> backend getImplBackend(const T &Impl) {
return Impl->getContextImplPtr()->getBackend();
return Impl->getContextImpl().getBackend();
}

} // namespace detail
Expand Down
26 changes: 13 additions & 13 deletions sycl/source/detail/bindless_images.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -813,9 +813,9 @@ get_image_memory_support(const image_descriptor &imageDescriptor,
const sycl::context &syclContext) {
std::shared_ptr<sycl::detail::device_impl> DevImpl =
sycl::detail::getSyclObjImpl(syclDevice);
std::shared_ptr<sycl::detail::context_impl> CtxImpl =
sycl::detail::getSyclObjImpl(syclContext);
const sycl::detail::AdapterPtr &Adapter = CtxImpl->getAdapter();
sycl::detail::context_impl &CtxImpl =
*sycl::detail::getSyclObjImpl(syclContext);
const sycl::detail::AdapterPtr &Adapter = CtxImpl.getAdapter();

ur_image_desc_t urDesc;
ur_image_format_t urFormat;
Expand All @@ -825,15 +825,15 @@ get_image_memory_support(const image_descriptor &imageDescriptor,
Adapter->call<sycl::errc::runtime,
sycl::detail::UrApiKind::
urBindlessImagesGetImageMemoryHandleTypeSupportExp>(
CtxImpl->getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
CtxImpl.getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
ur_exp_image_mem_type_t::UR_EXP_IMAGE_MEM_TYPE_USM_POINTER,
&supportsPointerAllocation);

ur_bool_t supportsOpaqueAllocation{0};
Adapter->call<sycl::errc::runtime,
sycl::detail::UrApiKind::
urBindlessImagesGetImageMemoryHandleTypeSupportExp>(
CtxImpl->getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
CtxImpl.getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
ur_exp_image_mem_type_t::UR_EXP_IMAGE_MEM_TYPE_OPAQUE_HANDLE,
&supportsOpaqueAllocation);

Expand Down Expand Up @@ -864,9 +864,9 @@ __SYCL_EXPORT bool is_image_handle_supported<unsampled_image_handle>(
const sycl::device &syclDevice, const sycl::context &syclContext) {
std::shared_ptr<sycl::detail::device_impl> DevImpl =
sycl::detail::getSyclObjImpl(syclDevice);
std::shared_ptr<sycl::detail::context_impl> CtxImpl =
sycl::detail::getSyclObjImpl(syclContext);
const sycl::detail::AdapterPtr &Adapter = CtxImpl->getAdapter();
sycl::detail::context_impl &CtxImpl =
*sycl::detail::getSyclObjImpl(syclContext);
const sycl::detail::AdapterPtr &Adapter = CtxImpl.getAdapter();

ur_image_desc_t urDesc;
ur_image_format_t urFormat;
Expand All @@ -881,7 +881,7 @@ __SYCL_EXPORT bool is_image_handle_supported<unsampled_image_handle>(
Adapter->call<sycl::errc::runtime,
sycl::detail::UrApiKind::
urBindlessImagesGetImageUnsampledHandleSupportExp>(
CtxImpl->getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
CtxImpl.getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
memHandleType, &supportsUnsampledHandle);

return supportsUnsampledHandle;
Expand All @@ -904,9 +904,9 @@ __SYCL_EXPORT bool is_image_handle_supported<sampled_image_handle>(
const sycl::device &syclDevice, const sycl::context &syclContext) {
std::shared_ptr<sycl::detail::device_impl> DevImpl =
sycl::detail::getSyclObjImpl(syclDevice);
std::shared_ptr<sycl::detail::context_impl> CtxImpl =
sycl::detail::getSyclObjImpl(syclContext);
const sycl::detail::AdapterPtr &Adapter = CtxImpl->getAdapter();
sycl::detail::context_impl &CtxImpl =
*sycl::detail::getSyclObjImpl(syclContext);
const sycl::detail::AdapterPtr &Adapter = CtxImpl.getAdapter();

ur_image_desc_t urDesc;
ur_image_format_t urFormat;
Expand All @@ -921,7 +921,7 @@ __SYCL_EXPORT bool is_image_handle_supported<sampled_image_handle>(
Adapter->call<
sycl::errc::runtime,
sycl::detail::UrApiKind::urBindlessImagesGetImageSampledHandleSupportExp>(
CtxImpl->getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
CtxImpl.getHandleRef(), DevImpl->getHandleRef(), &urDesc, &urFormat,
memHandleType, &supportsSampledHandle);

return supportsSampledHandle;
Expand Down
6 changes: 4 additions & 2 deletions sycl/source/detail/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,14 @@ void GetCapabilitiesIntersectionSet(const std::vector<sycl::device> &Devices,

// We're under sycl/source and these won't be exported but it's way more
// convenient to be able to reference them without extra `detail::`.
inline auto get_ur_handles(const sycl::context &syclContext) {
sycl::detail::context_impl &Ctx = *sycl::detail::getSyclObjImpl(syclContext);
inline auto get_ur_handles(sycl::detail::context_impl &Ctx) {
ur_context_handle_t urCtx = Ctx.getHandleRef();
const sycl::detail::Adapter *Adapter = Ctx.getAdapter().get();
return std::tuple{urCtx, Adapter};
}
inline auto get_ur_handles(const sycl::context &syclContext) {
return get_ur_handles(*sycl::detail::getSyclObjImpl(syclContext));
}
inline auto get_ur_handles(const sycl::device &syclDevice,
const sycl::context &syclContext) {
auto [urCtx, Adapter] = get_ur_handles(syclContext);
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/detail/device_image_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,13 +570,13 @@ class device_image_impl {

ur_native_handle_t getNative() const {
assert(MProgram);
const auto &ContextImplPtr = detail::getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImplPtr->getAdapter();
context_impl &ContextImpl = *detail::getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImpl.getAdapter();

ur_native_handle_t NativeProgram = 0;
Adapter->call<UrApiKind::urProgramGetNativeHandle>(MProgram,
&NativeProgram);
if (ContextImplPtr->getBackend() == backend::opencl)
if (ContextImpl.getBackend() == backend::opencl)
__SYCL_OCL_CALL(clRetainProgram, ur::cast<cl_program>(NativeProgram));

return NativeProgram;
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/handler_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class handler_impl {
template <typename Self = handler_impl> context_impl &get_context() {
Self *self = this;
if (auto *Queue = self->get_queue_or_null())
return *Queue->getContextImplPtr();
return Queue->getContextImpl();
else
return *self->get_graph().getContextImplPtr();
}
Expand Down
11 changes: 5 additions & 6 deletions sycl/source/detail/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

namespace sycl {
inline namespace _V1 {
using ContextImplPtr = std::shared_ptr<sycl::detail::context_impl>;
namespace detail {
void waitEvents(std::vector<sycl::event> DepEvents) {
for (auto SyclEvent : DepEvents) {
Expand Down Expand Up @@ -59,10 +58,10 @@ retrieveKernelBinary(queue_impl &Queue, KernelNameStrRefT KernelName,
if (DeviceImage == DeviceImages.end()) {
return {nullptr, nullptr};
}
auto ContextImpl = Queue.getContextImplPtr();
context_impl &ContextImpl = Queue.getContextImpl();
ur_program_handle_t Program =
detail::ProgramManager::getInstance().createURProgram(
**DeviceImage, *ContextImpl, {createSyclObjFromImpl<device>(Dev)});
**DeviceImage, ContextImpl, {createSyclObjFromImpl<device>(Dev)});
return {*DeviceImage, Program};
}

Expand All @@ -80,11 +79,11 @@ retrieveKernelBinary(queue_impl &Queue, KernelNameStrRefT KernelName,
DeviceImage = SyclKernelImpl->getDeviceImage()->get_bin_image_ref();
Program = SyclKernelImpl->getDeviceImage()->get_ur_program_ref();
} else {
auto ContextImpl = Queue.getContextImplPtr();
context_impl &ContextImpl = Queue.getContextImpl();
DeviceImage = &detail::ProgramManager::getInstance().getDeviceImage(
KernelName, *ContextImpl, Dev);
KernelName, ContextImpl, Dev);
Program = detail::ProgramManager::getInstance().createURProgram(
*DeviceImage, *ContextImpl, {createSyclObjFromImpl<device>(Dev)});
*DeviceImage, ContextImpl, {createSyclObjFromImpl<device>(Dev)});
}
return {DeviceImage, Program};
}
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/kernel_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class kernel_impl {
bool isInterop() const { return MIsInterop; }

ur_program_handle_t getProgramRef() const { return MProgram; }
ContextImplPtr getContextImplPtr() const { return MContext; }
context_impl &getContextImpl() const { return *MContext; }

std::mutex &getNoncacheableEnqueueMutex() const {
return MNoncacheableEnqueueMutex;
Expand Down
8 changes: 4 additions & 4 deletions sycl/source/detail/queue_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ template <> device queue_impl::get_info<info::queue::device>() const {
template <>
typename info::platform::version::return_type
queue_impl::get_backend_info<info::platform::version>() const {
if (getContextImplPtr()->getBackend() != backend::opencl) {
if (getContextImpl().getBackend() != backend::opencl) {
throw sycl::exception(errc::backend_mismatch,
"the info::platform::version info descriptor can "
"only be queried with an OpenCL backend");
Expand All @@ -93,7 +93,7 @@ queue_impl::get_backend_info<info::platform::version>() const {
template <>
typename info::device::version::return_type
queue_impl::get_backend_info<info::device::version>() const {
if (getContextImplPtr()->getBackend() != backend::opencl) {
if (getContextImpl().getBackend() != backend::opencl) {
throw sycl::exception(errc::backend_mismatch,
"the info::device::version info descriptor can only "
"be queried with an OpenCL backend");
Expand All @@ -106,7 +106,7 @@ queue_impl::get_backend_info<info::device::version>() const {
template <>
typename info::device::backend_version::return_type
queue_impl::get_backend_info<info::device::backend_version>() const {
if (getContextImplPtr()->getBackend() != backend::ext_oneapi_level_zero) {
if (getContextImpl().getBackend() != backend::ext_oneapi_level_zero) {
throw sycl::exception(errc::backend_mismatch,
"the info::device::backend_version info descriptor "
"can only be queried with a Level Zero backend");
Expand Down Expand Up @@ -731,7 +731,7 @@ ur_native_handle_t queue_impl::getNative(int32_t &NativeHandleDesc) const {

Adapter->call<UrApiKind::urQueueGetNativeHandle>(MQueue, &UrNativeDesc,
&Handle);
if (getContextImplPtr()->getBackend() == backend::opencl)
if (getContextImpl().getBackend() == backend::opencl)
__SYCL_OCL_CALL(clRetainCommandQueue, ur::cast<cl_command_queue>(Handle));

return Handle;
Expand Down
5 changes: 3 additions & 2 deletions sycl/source/detail/queue_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {

const AdapterPtr &getAdapter() const { return MContext->getAdapter(); }

// TODO: stop using it in existing code. New code must NOT use this!
const ContextImplPtr &getContextImplPtr() const { return MContext; }

context_impl &getContextImpl() const { return *MContext; }
Expand Down Expand Up @@ -651,7 +652,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
void revisitUnenqueuedCommandsState(const EventImplPtr &CompletedHostTask);

static ContextImplPtr getContext(queue_impl *Queue) {
return Queue ? Queue->getContextImplPtr() : nullptr;
return Queue ? Queue->getContextImpl().shared_from_this() : nullptr;
}
static ContextImplPtr getContext(const QueueImplPtr &Queue) {
return getContext(Queue.get());
Expand Down Expand Up @@ -984,7 +985,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
mutable std::mutex MMutex;

device_impl &MDevice;
const ContextImplPtr MContext;
const std::shared_ptr<context_impl> MContext;

/// These events are tracked, but not owned, by the queue.
std::vector<std::weak_ptr<event_impl>> MEventsWeak;
Expand Down
14 changes: 9 additions & 5 deletions sycl/source/detail/sampler_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,14 @@ sampler_impl::~sampler_impl() {
}

ur_sampler_handle_t
sampler_impl::getOrCreateSampler(const ContextImplPtr &ContextImpl) {
sampler_impl::getOrCreateSampler(context_impl &ContextImpl) {
// Just for the `MContextToSampler` lookups. Probably the type of it should be
// changed.
std::shared_ptr<context_impl> ContextImplPtr = ContextImpl.shared_from_this();

{
std::lock_guard<std::mutex> Lock(MMutex);
auto It = MContextToSampler.find(ContextImpl);
auto It = MContextToSampler.find(ContextImplPtr);
if (It != MContextToSampler.end())
return It->second;
}
Expand Down Expand Up @@ -135,18 +139,18 @@ sampler_impl::getOrCreateSampler(const ContextImplPtr &ContextImpl) {

ur_result_t errcode_ret = UR_RESULT_SUCCESS;
ur_sampler_handle_t resultSampler = nullptr;
const AdapterPtr &Adapter = ContextImpl->getAdapter();
const AdapterPtr &Adapter = ContextImpl.getAdapter();

errcode_ret = Adapter->call_nocheck<UrApiKind::urSamplerCreate>(
ContextImpl->getHandleRef(), &desc, &resultSampler);
ContextImpl.getHandleRef(), &desc, &resultSampler);

if (errcode_ret == UR_RESULT_ERROR_UNSUPPORTED_FEATURE)
throw sycl::exception(sycl::errc::feature_not_supported,
"Images are not supported by this device.");

Adapter->checkUrResult(errcode_ret);
std::lock_guard<std::mutex> Lock(MMutex);
MContextToSampler[ContextImpl] = resultSampler;
MContextToSampler[ContextImplPtr] = resultSampler;

return resultSampler;
}
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/sampler_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class sampler_impl {

coordinate_normalization_mode get_coordinate_normalization_mode() const;

ur_sampler_handle_t getOrCreateSampler(const ContextImplPtr &ContextImpl);
ur_sampler_handle_t getOrCreateSampler(context_impl &ContextImpl);

~sampler_impl();

Expand Down
Loading
Loading