Skip to content

Commit 8eaa0d1

Browse files
[NFC][SYCL] Use raw context_impl in event_impl::[set|get]Context
1 parent 610d42a commit 8eaa0d1

File tree

10 files changed

+33
-32
lines changed

10 files changed

+33
-32
lines changed

sycl/source/detail/event_impl.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ void event_impl::initContextIfNeeded() {
3838
return;
3939

4040
const device SyclDevice;
41-
this->setContextImpl(
42-
detail::queue_impl::getDefaultOrNew(*detail::getSyclObjImpl(SyclDevice)));
41+
MIsHostEvent = false;
42+
MContext =
43+
detail::queue_impl::getDefaultOrNew(*detail::getSyclObjImpl(SyclDevice));
4344
}
4445

4546
event_impl::~event_impl() {
@@ -140,9 +141,10 @@ void event_impl::setHandle(const ur_event_handle_t &UREvent) {
140141
MEvent.store(UREvent);
141142
}
142143

143-
const ContextImplPtr &event_impl::getContextImpl() {
144+
context_impl &event_impl::getContextImpl() {
144145
initContextIfNeeded();
145-
return MContext;
146+
assert(MContext && "Trying to get context from a host event!");
147+
return *MContext;
146148
}
147149

148150
const AdapterPtr &event_impl::getAdapter() {
@@ -152,9 +154,13 @@ const AdapterPtr &event_impl::getAdapter() {
152154

153155
void event_impl::setStateIncomplete() { MState = HES_NotComplete; }
154156

155-
void event_impl::setContextImpl(const ContextImplPtr &Context) {
157+
void event_impl::setContextImpl(context_impl &Context) {
158+
MIsHostEvent = false;
159+
MContext = Context.shared_from_this();
160+
}
161+
void event_impl::setContextImpl(context_impl *Context) {
156162
MIsHostEvent = Context == nullptr;
157-
MContext = Context;
163+
MContext = Context ? Context->shared_from_this() : nullptr;
158164
}
159165

160166
event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext,
@@ -178,7 +184,7 @@ event_impl::event_impl(ur_event_handle_t Event, const context &SyclContext,
178184
event_impl::event_impl(queue_impl &Queue, private_tag)
179185
: MQueue{Queue.weak_from_this()},
180186
MIsProfilingEnabled{Queue.MIsProfilingEnabled} {
181-
this->setContextImpl(Queue.getContextImplPtr());
187+
this->setContextImpl(Queue.getContextImpl());
182188
MState.store(HES_Complete);
183189
}
184190

sycl/source/detail/event_impl.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,21 +174,18 @@ class event_impl : public std::enable_shared_from_this<event_impl> {
174174
void setHandle(const ur_event_handle_t &UREvent);
175175

176176
/// Returns context that is associated with this event.
177-
///
178-
/// \return a shared pointer to a valid context_impl.
179-
const ContextImplPtr &getContextImpl();
177+
context_impl &getContextImpl();
180178

181179
/// \return the Adapter associated with the context of this event.
182180
/// Should be called when this is not a Host Event.
183181
const AdapterPtr &getAdapter();
184182

185183
/// Associate event with the context.
186184
///
187-
/// Provided UrContext inside ContextImplPtr must be associated
185+
/// Provided UrContext inside Context must be associated
188186
/// with the UrEvent object stored in this class
189-
///
190-
/// @param Context is a shared pointer to an instance of valid context_impl.
191-
void setContextImpl(const ContextImplPtr &Context);
187+
void setContextImpl(context_impl &Context);
188+
void setContextImpl(context_impl *Context);
192189

193190
/// Clear the event state
194191
void setStateIncomplete();

sycl/source/detail/graph_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ exec_graph_impl::enqueue(sycl::detail::queue_impl &Queue,
10371037

10381038
auto CreateNewEvent([&]() {
10391039
auto NewEvent = sycl::detail::event_impl::create_device_event(Queue);
1040-
NewEvent->setContextImpl(Queue.getContextImplPtr());
1040+
NewEvent->setContextImpl(Queue.getContextImpl());
10411041
NewEvent->setStateIncomplete();
10421042
return NewEvent;
10431043
});

sycl/source/detail/queue_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ queue_impl::get_backend_info<info::device::backend_version>() const {
121121
static event prepareSYCLEventAssociatedWithQueue(
122122
const std::shared_ptr<detail::queue_impl> &QueueImpl) {
123123
auto EventImpl = detail::event_impl::create_device_event(*QueueImpl);
124-
EventImpl->setContextImpl(detail::getSyclObjImpl(QueueImpl->get_context()));
124+
EventImpl->setContextImpl(QueueImpl->getContextImpl());
125125
EventImpl->setStateIncomplete();
126126
return detail::createSyclObjFromImpl<event>(EventImpl);
127127
}

sycl/source/detail/reduction.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ __SYCL_EXPORT void
208208
addCounterInit(handler &CGH, std::shared_ptr<sycl::detail::queue_impl> &Queue,
209209
std::shared_ptr<int> &Counter) {
210210
auto EventImpl = detail::event_impl::create_device_event(*Queue);
211-
EventImpl->setContextImpl(detail::getSyclObjImpl(Queue->get_context()));
211+
EventImpl->setContextImpl(Queue->getContextImpl());
212212
EventImpl->setStateIncomplete();
213213
ur_event_handle_t UREvent = nullptr;
214214
MemoryManager::fill_usm(Counter.get(), *Queue, sizeof(int), {0}, {},

sycl/source/detail/scheduler/commands.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -537,10 +537,8 @@ void Command::waitForEvents(QueueImplPtr Queue,
537537
RequiredEventsPerContext;
538538

539539
for (const EventImplPtr &Event : EventImpls) {
540-
ContextImplPtr Context = Event->getContextImpl();
541-
assert(Context.get() &&
542-
"Only non-host events are expected to be waited for here");
543-
RequiredEventsPerContext[Context.get()].push_back(Event);
540+
context_impl &Context = Event->getContextImpl();
541+
RequiredEventsPerContext[&Context].push_back(Event);
544542
}
545543

546544
for (auto &CtxWithEvents : RequiredEventsPerContext) {
@@ -580,7 +578,7 @@ Command::Command(
580578
MEvent->setSubmittedQueue(MWorkerQueue);
581579
MEvent->setCommand(this);
582580
if (MQueue)
583-
MEvent->setContextImpl(MQueue->getContextImplPtr());
581+
MEvent->setContextImpl(MQueue->getContextImpl());
584582
MEvent->setStateIncomplete();
585583
MEnqueueStatus = EnqueueResultT::SyclEnqueueReady;
586584

@@ -785,9 +783,9 @@ Command *Command::processDepEvent(EventImplPtr DepEvent, const DepDesc &Dep,
785783

786784
Command *ConnectionCmd = nullptr;
787785

788-
ContextImplPtr DepEventContext = DepEvent->getContextImpl();
786+
context_impl &DepEventContext = DepEvent->getContextImpl();
789787
// If contexts don't match we'll connect them using host task
790-
if (DepEventContext != WorkerContext && WorkerContext) {
788+
if (&DepEventContext != WorkerContext.get() && WorkerContext) {
791789
Scheduler::GraphBuilder &GB = Scheduler::getInstance().MGraphBuilder;
792790
ConnectionCmd = GB.connectDepEvent(this, DepEvent, Dep, ToCleanUp);
793791
} else
@@ -1303,7 +1301,7 @@ ur_result_t ReleaseCommand::enqueueImp() {
13031301

13041302
std::shared_ptr<event_impl> UnmapEventImpl =
13051303
event_impl::create_device_event(*Queue);
1306-
UnmapEventImpl->setContextImpl(Queue->getContextImplPtr());
1304+
UnmapEventImpl->setContextImpl(Queue->getContextImpl());
13071305
UnmapEventImpl->setStateIncomplete();
13081306
ur_event_handle_t UREvent = nullptr;
13091307

@@ -1522,7 +1520,7 @@ MemCpyCommand::MemCpyCommand(Requirement SrcReq,
15221520
MSrcAllocaCmd(SrcAllocaCmd), MDstReq(std::move(DstReq)),
15231521
MDstAllocaCmd(DstAllocaCmd) {
15241522
if (MSrcQueue) {
1525-
MEvent->setContextImpl(MSrcQueue->getContextImplPtr());
1523+
MEvent->setContextImpl(MSrcQueue->getContextImpl());
15261524
}
15271525

15281526
MWorkerQueue = !MQueue ? MSrcQueue : MQueue;
@@ -1695,7 +1693,7 @@ MemCpyCommandHost::MemCpyCommandHost(Requirement SrcReq,
16951693
MSrcQueue(SrcQueue), MSrcReq(std::move(SrcReq)),
16961694
MSrcAllocaCmd(SrcAllocaCmd), MDstReq(std::move(DstReq)), MDstPtr(DstPtr) {
16971695
if (MSrcQueue) {
1698-
MEvent->setContextImpl(MSrcQueue->getContextImplPtr());
1696+
MEvent->setContextImpl(MSrcQueue->getContextImpl());
16991697
}
17001698

17011699
MWorkerQueue = !MQueue ? MSrcQueue : MQueue;

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1217,7 +1217,7 @@ void Scheduler::GraphBuilder::removeRecordForMemObj(SYCLMemObjI *MemObject) {
12171217
Command *Scheduler::GraphBuilder::connectDepEvent(
12181218
Command *const Cmd, const EventImplPtr &DepEvent, const DepDesc &Dep,
12191219
std::vector<Command *> &ToCleanUp) {
1220-
assert(Cmd->getWorkerContext() != DepEvent->getContextImpl());
1220+
assert(Cmd->getWorkerContext().get() != &DepEvent->getContextImpl());
12211221

12221222
// construct Host Task type command manually and make it depend on DepEvent
12231223
ExecCGCommand *ConnectCmd = nullptr;

sycl/source/detail/scheduler/scheduler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ bool Scheduler::CheckEventReadiness(context_impl &Context,
691691
return SyclEventImplPtr->isCompleted();
692692
}
693693
// Cross-context dependencies can't be passed to the backend directly.
694-
if (SyclEventImplPtr->getContextImpl().get() != &Context)
694+
if (&SyclEventImplPtr->getContextImpl() != &Context)
695695
return false;
696696

697697
// A nullptr here means that the commmand does not produce a UR event or it

sycl/source/handler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ event handler::finalize() {
610610
detail::queue_impl &Queue = impl->get_queue();
611611
LastEventImpl->setQueue(Queue);
612612
LastEventImpl->setWorkerQueue(Queue.weak_from_this());
613-
LastEventImpl->setContextImpl(impl->get_context().shared_from_this());
613+
LastEventImpl->setContextImpl(impl->get_context());
614614
LastEventImpl->setStateIncomplete();
615615
LastEventImpl->setSubmissionTime();
616616

sycl/unittests/scheduler/QueueFlushing.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ TEST_F(SchedulerTest, QueueFlushing) {
151151
access::mode::read_write};
152152
std::shared_ptr<detail::event_impl> DepEvent =
153153
detail::event_impl::create_device_event(*QueueImplB);
154-
DepEvent->setContextImpl(QueueImplB->getContextImplPtr());
154+
DepEvent->setContextImpl(QueueImplB->getContextImpl());
155155

156156
ur_event_handle_t UREvent = mock::createDummyHandle<ur_event_handle_t>();
157157

@@ -171,7 +171,7 @@ TEST_F(SchedulerTest, QueueFlushing) {
171171
queue TempQueue{Ctx, default_selector_v};
172172
detail::queue_impl &TempQueueImpl = *detail::getSyclObjImpl(TempQueue);
173173
DepEvent = detail::event_impl::create_device_event(TempQueueImpl);
174-
DepEvent->setContextImpl(TempQueueImpl.getContextImplPtr());
174+
DepEvent->setContextImpl(TempQueueImpl.getContextImpl());
175175

176176
ur_event_handle_t UREvent = mock::createDummyHandle<ur_event_handle_t>();
177177

0 commit comments

Comments
 (0)