Skip to content

Commit f088e38

Browse files
authored
[SYCL] Add implementation of host-interop-task and test. (#1748)
This patch is the second in a series of patches for the interop part of host task. It introduces an API to enqueue a host task with an `interop_handle` argument. See the proposal: https://github.com/codeplaysoftware/standards-proposals/blob/master/host_task/host_task.md Signed-off-by: Sergey Kanaev <[email protected]>
1 parent d7ee359 commit f088e38

File tree

11 files changed

+454
-26
lines changed

11 files changed

+454
-26
lines changed

sycl/include/CL/sycl/detail/cg.hpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,9 +305,16 @@ class CGInteropTask : public CG {
305305
class CGHostTask : public CG {
306306
public:
307307
std::unique_ptr<HostTask> MHostTask;
308+
// queue for host-interop task
309+
shared_ptr_class<detail::queue_impl> MQueue;
310+
// context for host-interop task
311+
shared_ptr_class<detail::context_impl> MContext;
308312
vector_class<ArgDesc> MArgs;
309313

310-
CGHostTask(std::unique_ptr<HostTask> HostTask, vector_class<ArgDesc> Args,
314+
CGHostTask(std::unique_ptr<HostTask> HostTask,
315+
std::shared_ptr<detail::queue_impl> Queue,
316+
std::shared_ptr<detail::context_impl> Context,
317+
vector_class<ArgDesc> Args,
311318
std::vector<std::vector<char>> ArgsStorage,
312319
std::vector<detail::AccessorImplPtr> AccStorage,
313320
std::vector<std::shared_ptr<const void>> SharedPtrStorage,
@@ -317,7 +324,8 @@ class CGHostTask : public CG {
317324
: CG(Type, std::move(ArgsStorage), std::move(AccStorage),
318325
std::move(SharedPtrStorage), std::move(Requirements),
319326
std::move(Events), std::move(loc)),
320-
MHostTask(std::move(HostTask)), MArgs(std::move(Args)) {}
327+
MHostTask(std::move(HostTask)), MQueue(Queue), MContext(Context),
328+
MArgs(std::move(Args)) {}
321329
};
322330

323331
class CGBarrier : public CG {

sycl/include/CL/sycl/detail/cg_types.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <CL/sycl/detail/kernel_desc.hpp>
1313
#include <CL/sycl/group.hpp>
1414
#include <CL/sycl/id.hpp>
15+
#include <CL/sycl/interop_handle.hpp>
1516
#include <CL/sycl/interop_handler.hpp>
1617
#include <CL/sycl/kernel.hpp>
1718
#include <CL/sycl/nd_item.hpp>
@@ -143,12 +144,17 @@ class InteropTask {
143144

144145
// Holder for a user-supplied host task callable. It stores one of two
// flavours: a plain `void()` callable or an interop callable that receives an
// `interop_handle`. Each converting constructor fills exactly one of the two
// members; isInteropTask() tells callers which call() overload to use.
class HostTask {
145146
// Plain host task; set by the default and `void()` constructors.
std::function<void()> MHostTask;
147+
// Interop host task; set only by the `void(interop_handle)` constructor and
// left empty otherwise.
std::function<void(interop_handle)> MInteropTask;
146148

147149
public:
148150
// Default: a no-op plain task (MInteropTask stays empty).
HostTask() : MHostTask([]() {}) {}
149151
// NOTE(review): `Func` is a named rvalue reference, so `MHostTask(Func)`
// copy-constructs rather than moves — confirm whether std::move was intended.
HostTask(std::function<void()> &&Func) : MHostTask(Func) {}
152+
// Interop flavour; MHostTask is left default-constructed (empty) here.
// NOTE(review): same copy-vs-move remark as the constructor above.
HostTask(std::function<void(interop_handle)> &&Func) : MInteropTask(Func) {}
153+
154+
// True when an interop callable is stored (std::function's bool conversion).
bool isInteropTask() const { return !!MInteropTask; }
150155

151156
// Invokes the plain task. Invoking an empty std::function throws
// std::bad_function_call, so call this only when isInteropTask() is false.
void call() { MHostTask(); }
157+
// Invokes the interop task with `handle`; call this only when
// isInteropTask() is true, for the same reason.
void call(interop_handle handle) { MInteropTask(handle); }
152158
};
153159

154160
// Class which stores specific lambda object.

sycl/include/CL/sycl/detail/sycl_mem_obj_i.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class SYCLMemObjI {
6969
// which is unavailable.
7070
shared_ptr_class<MemObjRecord> MRecord;
7171
friend class Scheduler;
72+
friend class ExecCGCommand;
7273
};
7374

7475
} // namespace detail

sycl/include/CL/sycl/handler.hpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <CL/sycl/detail/os_util.hpp>
2020
#include <CL/sycl/event.hpp>
2121
#include <CL/sycl/id.hpp>
22+
#include <CL/sycl/interop_handle.hpp>
2223
#include <CL/sycl/item.hpp>
2324
#include <CL/sycl/kernel.hpp>
2425
#include <CL/sycl/nd_item.hpp>
@@ -856,8 +857,22 @@ class __SYCL_EXPORT handler {
856857
}
857858

858859
template <typename FuncT>
859-
typename std::enable_if<detail::check_fn_signature<
860-
typename std::remove_reference<FuncT>::type, void()>::value>::type
860+
detail::enable_if_t<detail::check_fn_signature<
861+
detail::remove_reference_t<FuncT>, void()>::value>
862+
codeplay_host_task(FuncT Func) {
863+
throwIfActionIsCreated();
864+
865+
MNDRDesc.set(range<1>(1));
866+
MArgs = std::move(MAssociatedAccesors);
867+
868+
MHostTask.reset(new detail::HostTask(std::move(Func)));
869+
870+
MCGType = detail::CG::CODEPLAY_HOST_TASK;
871+
}
872+
873+
template <typename FuncT>
874+
detail::enable_if_t<detail::check_fn_signature<
875+
detail::remove_reference_t<FuncT>, void(interop_handle)>::value>
861876
codeplay_host_task(FuncT Func) {
862877
throwIfActionIsCreated();
863878

sycl/include/CL/sycl/interop_handle.hpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,13 @@ class interop_handle {
8484
template <backend BackendName = backend::opencl>
8585
auto get_native_queue() const noexcept ->
8686
typename interop<BackendName, queue>::type {
87+
#ifndef __SYCL_DEVICE_ONLY__
8788
return reinterpret_cast<typename interop<BackendName, queue>::type>(
8889
getNativeQueue());
90+
#else
91+
// we believe this won't be ever called on device side
92+
return nullptr;
93+
#endif
8994
}
9095

9196
/// Returns an underlying OpenCL device associated with the SYCL queue used
@@ -94,8 +99,13 @@ class interop_handle {
9499
template <backend BackendName = backend::opencl>
95100
auto get_native_device() const noexcept ->
96101
typename interop<BackendName, device>::type {
102+
#ifndef __SYCL_DEVICE_ONLY__
97103
return reinterpret_cast<typename interop<BackendName, device>::type>(
98104
getNativeDevice());
105+
#else
106+
// we believe this won't be ever called on device side
107+
return nullptr;
108+
#endif
99109
}
100110

101111
/// Returns an underlying OpenCL context associated with the SYCL queue used
@@ -104,14 +114,20 @@ class interop_handle {
104114
template <backend BackendName = backend::opencl>
105115
auto get_native_context() const noexcept ->
106116
typename interop<BackendName, context>::type {
117+
#ifndef __SYCL_DEVICE_ONLY__
107118
return reinterpret_cast<typename interop<BackendName, context>::type>(
108119
getNativeContext());
120+
#else
121+
// we believe this won't be ever called on device side
122+
return nullptr;
123+
#endif
109124
}
110125

111126
private:
127+
friend class detail::ExecCGCommand;
128+
friend class detail::DispatchHostTask;
112129
using ReqToMem = std::pair<detail::Requirement *, pi_mem>;
113130

114-
public:
115131
// TODO set c-tor private
116132
interop_handle(std::vector<ReqToMem> MemObjs,
117133
const std::shared_ptr<detail::queue_impl> &Queue,
@@ -131,10 +147,10 @@ class interop_handle {
131147
getNativeMem(Req));
132148
}
133149

134-
pi_native_handle getNativeMem(detail::Requirement *Req) const;
135-
pi_native_handle getNativeQueue() const;
136-
pi_native_handle getNativeDevice() const;
137-
pi_native_handle getNativeContext() const;
150+
__SYCL_EXPORT pi_native_handle getNativeMem(detail::Requirement *Req) const;
151+
__SYCL_EXPORT pi_native_handle getNativeQueue() const;
152+
__SYCL_EXPORT pi_native_handle getNativeDevice() const;
153+
__SYCL_EXPORT pi_native_handle getNativeContext() const;
138154

139155
std::shared_ptr<detail::queue_impl> MQueue;
140156
std::shared_ptr<detail::device_impl> MDevice;

sycl/source/detail/scheduler/commands.cpp

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88

99
#include <detail/error_handling/error_handling.hpp>
1010

11-
#include "CL/sycl/access/access.hpp"
11+
#include <CL/sycl/access/access.hpp>
1212
#include <CL/sycl/backend_types.hpp>
13+
#include <CL/sycl/detail/cg_types.hpp>
1314
#include <CL/sycl/detail/cl.h>
1415
#include <CL/sycl/detail/kernel_desc.hpp>
1516
#include <CL/sycl/detail/memory_manager.hpp>
@@ -159,6 +160,7 @@ getPiEvents(const std::vector<EventImplPtr> &EventImpls) {
159160

160161
class DispatchHostTask {
161162
ExecCGCommand *MThisCmd;
163+
std::vector<interop_handle::ReqToMem> MReqToMem;
162164

163165
void waitForEvents() const {
164166
std::map<const detail::plugin *, std::vector<EventImplPtr>>
@@ -187,7 +189,9 @@ class DispatchHostTask {
187189
}
188190

189191
public:
190-
DispatchHostTask(ExecCGCommand *ThisCmd) : MThisCmd{ThisCmd} {}
192+
DispatchHostTask(ExecCGCommand *ThisCmd,
193+
std::vector<interop_handle::ReqToMem> ReqToMem)
194+
: MThisCmd{ThisCmd}, MReqToMem(std::move(ReqToMem)) {}
191195

192196
void operator()() const {
193197
waitForEvents();
@@ -197,7 +201,15 @@ class DispatchHostTask {
197201
CGHostTask &HostTask = static_cast<CGHostTask &>(MThisCmd->getCG());
198202

199203
// we're ready to call the user-defined lambda now
200-
HostTask.MHostTask->call();
204+
if (HostTask.MHostTask->isInteropTask()) {
205+
interop_handle IH{MReqToMem, HostTask.MQueue,
206+
getSyclObjImpl(HostTask.MQueue->get_device()),
207+
HostTask.MQueue->getContextImplPtr()};
208+
209+
HostTask.MHostTask->call(IH);
210+
} else
211+
HostTask.MHostTask->call();
212+
201213
HostTask.MHostTask.reset();
202214

203215
// unblock user empty command here
@@ -1943,7 +1955,38 @@ cl_int ExecCGCommand::enqueueImp() {
19431955
}
19441956
}
19451957

1946-
MQueue->getThreadPool().submit<DispatchHostTask>(DispatchHostTask(this));
1958+
std::vector<interop_handle::ReqToMem> ReqToMem;
1959+
1960+
if (HostTask->MHostTask->isInteropTask()) {
1961+
// Extract the Mem Objects for all Requirements, to ensure they are
1962+
// available if a user asks for them inside the interop task scope
1963+
const std::vector<Requirement *> &HandlerReq = HostTask->MRequirements;
1964+
auto ReqToMemConv = [&ReqToMem, HostTask](Requirement *Req) {
1965+
const std::vector<AllocaCommandBase *> &AllocaCmds =
1966+
Req->MSYCLMemObj->MRecord->MAllocaCommands;
1967+
1968+
for (AllocaCommandBase *AllocaCmd : AllocaCmds)
1969+
if (HostTask->MQueue == AllocaCmd->getQueue()) {
1970+
auto MemArg =
1971+
reinterpret_cast<pi_mem>(AllocaCmd->getMemAllocation());
1972+
ReqToMem.emplace_back(std::make_pair(Req, MemArg));
1973+
1974+
return;
1975+
}
1976+
1977+
assert(false &&
1978+
"Can't get memory object due to no allocation available");
1979+
1980+
throw runtime_error(
1981+
"Can't get memory object due to no allocation available",
1982+
PI_INVALID_MEM_OBJECT);
1983+
};
1984+
std::for_each(std::begin(HandlerReq), std::end(HandlerReq), ReqToMemConv);
1985+
std::sort(std::begin(ReqToMem), std::end(ReqToMem));
1986+
}
1987+
1988+
MQueue->getThreadPool().submit<DispatchHostTask>(
1989+
DispatchHostTask(this, std::move(ReqToMem)));
19471990

19481991
MShouldCompleteEventIfPossible = false;
19491992

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,16 @@ Scheduler::GraphBuilder::addEmptyCmd(Command *Cmd, const std::vector<T *> &Reqs,
679679
return EmptyCmd;
680680
}
681681

682+
// Returns true only when `Cmd` wraps a command group of type
// CODEPLAY_HOST_TASK whose stored HostTask reports isInteropTask().
// `Cmd` and the host task's MHostTask are dereferenced without null checks.
static bool isInteropHostTask(const std::unique_ptr<ExecCGCommand> &Cmd) {
683+
// Any other command-group type cannot be a host task at all.
if (Cmd->getCG().getType() != CG::CGTYPE::CODEPLAY_HOST_TASK)
684+
return false;
685+
686+
// The type check above is what justifies the downcast to CGHostTask.
const detail::CGHostTask &HT =
687+
static_cast<detail::CGHostTask &>(Cmd->getCG());
688+
689+
return HT.MHostTask->isInteropTask();
690+
}
691+
682692
Command *
683693
Scheduler::GraphBuilder::addCG(std::unique_ptr<detail::CG> CommandGroup,
684694
QueueImplPtr Queue) {
@@ -695,13 +705,29 @@ Scheduler::GraphBuilder::addCG(std::unique_ptr<detail::CG> CommandGroup,
695705
printGraphAsDot("before_addCG");
696706

697707
for (Requirement *Req : Reqs) {
698-
MemObjRecord *Record = getOrInsertMemObjRecord(Queue, Req);
699-
markModifiedIfWrite(Record, Req);
708+
MemObjRecord *Record = nullptr;
709+
AllocaCommandBase *AllocaCmd = nullptr;
710+
711+
bool isSameCtx = false;
712+
713+
{
714+
const QueueImplPtr &QueueForAlloca =
715+
isInteropHostTask(NewCmd)
716+
? static_cast<detail::CGHostTask &>(NewCmd->getCG()).MQueue
717+
: Queue;
718+
719+
Record = getOrInsertMemObjRecord(QueueForAlloca, Req);
720+
markModifiedIfWrite(Record, Req);
721+
722+
AllocaCmd = getOrCreateAllocaForReq(Record, Req, QueueForAlloca);
723+
724+
isSameCtx =
725+
sameCtx(QueueForAlloca->getContextImplPtr(), Record->MCurContext);
726+
}
700727

701-
AllocaCommandBase *AllocaCmd = getOrCreateAllocaForReq(Record, Req, Queue);
702728
// If there is alloca command we need to check if the latest memory is in
703729
// required context.
704-
if (sameCtx(Queue->getContextImplPtr(), Record->MCurContext)) {
730+
if (isSameCtx) {
705731
// If the memory is already in the required host context, check if the
706732
// required access mode is valid, remap if not.
707733
if (Record->MCurContext->is_host() &&
@@ -710,10 +736,24 @@ Scheduler::GraphBuilder::addCG(std::unique_ptr<detail::CG> CommandGroup,
710736
} else {
711737
// Cannot directly copy memory from OpenCL device to OpenCL device -
712738
// create two copies: device->host and host->device.
713-
if (!Queue->is_host() && !Record->MCurContext->is_host())
739+
bool NeedMemMoveToHost = false;
740+
auto MemMoveTargetQueue = Queue;
741+
742+
if (isInteropHostTask(NewCmd)) {
743+
const detail::CGHostTask &HT =
744+
static_cast<detail::CGHostTask &>(NewCmd->getCG());
745+
746+
if (HT.MQueue->getContextImplPtr() != Record->MCurContext) {
747+
NeedMemMoveToHost = true;
748+
MemMoveTargetQueue = HT.MQueue;
749+
}
750+
} else if (!Queue->is_host() && !Record->MCurContext->is_host())
751+
NeedMemMoveToHost = true;
752+
753+
if (NeedMemMoveToHost)
714754
insertMemoryMove(Record, Req,
715755
Scheduler::getInstance().getDefaultHostQueue());
716-
insertMemoryMove(Record, Req, Queue);
756+
insertMemoryMove(Record, Req, MemMoveTargetQueue);
717757
}
718758
std::set<Command *> Deps =
719759
findDepsForReq(Record, Req, Queue->getContextImplPtr());
@@ -927,10 +967,11 @@ void Scheduler::GraphBuilder::connectDepEvent(Command *const Cmd,
927967
{
928968
std::unique_ptr<detail::HostTask> HT(new detail::HostTask);
929969
std::unique_ptr<detail::CG> ConnectCG(new detail::CGHostTask(
930-
std::move(HT), /* Args = */ {}, /* ArgsStorage = */ {},
931-
/* AccStorage = */ {}, /* SharedPtrStorage = */ {},
932-
/* Requirements = */ {}, /* DepEvents = */ {DepEvent},
933-
CG::CODEPLAY_HOST_TASK, /* Payload */ {}));
970+
std::move(HT), /* Queue = */ {}, /* Context = */ {}, /* Args = */ {},
971+
/* ArgsStorage = */ {}, /* AccStorage = */ {},
972+
/* SharedPtrStorage = */ {}, /* Requirements = */ {},
973+
/* DepEvents = */ {DepEvent}, CG::CODEPLAY_HOST_TASK,
974+
/* Payload */ {}));
934975
ConnectCmd = new ExecCGCommand(
935976
std::move(ConnectCG), Scheduler::getInstance().getDefaultHostQueue());
936977
}

sycl/source/handler.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,10 @@ event handler::finalize() {
8585
break;
8686
case detail::CG::CODEPLAY_HOST_TASK:
8787
CommandGroup.reset(new detail::CGHostTask(
88-
std::move(MHostTask), std::move(MArgs), std::move(MArgsStorage),
89-
std::move(MAccStorage), std::move(MSharedPtrStorage),
90-
std::move(MRequirements), std::move(MEvents), MCGType, MCodeLoc));
88+
std::move(MHostTask), MQueue, MQueue->getContextImplPtr(),
89+
std::move(MArgs), std::move(MArgsStorage), std::move(MAccStorage),
90+
std::move(MSharedPtrStorage), std::move(MRequirements),
91+
std::move(MEvents), MCGType, MCodeLoc));
9192
break;
9293
case detail::CG::BARRIER:
9394
case detail::CG::BARRIER_WAITLIST:

sycl/test/abi/sycl_symbols_linux.dump

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3320,6 +3320,10 @@ _ZNK2cl4sycl13host_selectorclERKNS0_6deviceE
33203320
_ZNK2cl4sycl14exception_list3endEv
33213321
_ZNK2cl4sycl14exception_list4sizeEv
33223322
_ZNK2cl4sycl14exception_list5beginEv
3323+
_ZNK2cl4sycl14interop_handle12getNativeMemEPNS0_6detail16AccessorImplHostE
3324+
_ZNK2cl4sycl14interop_handle14getNativeQueueEv
3325+
_ZNK2cl4sycl14interop_handle15getNativeDeviceEv
3326+
_ZNK2cl4sycl14interop_handle16getNativeContextEv
33233327
_ZNK2cl4sycl15device_selector13select_deviceEv
33243328
_ZNK2cl4sycl15interop_handler12GetNativeMemEPNS0_6detail16AccessorImplHostE
33253329
_ZNK2cl4sycl15interop_handler14GetNativeQueueEv

0 commit comments

Comments
 (0)