Skip to content

Commit 2f253a9

Browse files
authored
[SYCL][Fusion] Restrict types of fusable command groups (#12556)
Only allow command groups of `Kernel` type. Do not add other kind of command groups to the fusable graph when found, showing a descriptive warning. --------- Signed-off-by: Victor Perez <[email protected]>
1 parent 18d6471 commit 2f253a9

File tree

5 files changed

+206
-58
lines changed

5 files changed

+206
-58
lines changed

sycl/source/detail/jit_compiler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
672672
for (auto &RawCmd : InputKernels) {
673673
auto *KernelCmd = static_cast<ExecCGCommand *>(RawCmd);
674674
auto &CG = KernelCmd->getCG();
675-
assert(CG.getType() == CG::Kernel);
675+
assert(KernelCmd->isFusable());
676676
auto *KernelCG = static_cast<CGExecKernel *>(&CG);
677677

678678
auto KernelName = KernelCG->MKernelName;

sycl/source/detail/scheduler/commands.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,12 @@ bool Command::isHostTask() const {
299299
CG::CGTYPE::CodeplayHostTask);
300300
}
301301

302+
bool Command::isFusable() const {
303+
return (MType == CommandType::RUN_CG) &&
304+
((static_cast<const ExecCGCommand *>(this))->getCG().getType() ==
305+
CG::CGTYPE::Kernel);
306+
}
307+
302308
static void flushCrossQueueDeps(const std::vector<EventImplPtr> &EventImpls,
303309
const QueueImplPtr &Queue) {
304310
for (auto &EventImpl : EventImpls) {
@@ -1825,7 +1831,7 @@ void UpdateHostRequirementCommand::emitInstrumentationData() {
18251831
#endif
18261832
}
18271833

1828-
static std::string cgTypeToString(detail::CG::CGTYPE Type) {
1834+
static std::string_view cgTypeToString(detail::CG::CGTYPE Type) {
18291835
switch (Type) {
18301836
case detail::CG::Kernel:
18311837
return "Kernel";
@@ -1845,6 +1851,10 @@ static std::string cgTypeToString(detail::CG::CGTYPE Type) {
18451851
case detail::CG::CopyPtrToAcc:
18461852
return "copy ptr to acc";
18471853
break;
1854+
case detail::CG::Barrier:
1855+
return "barrier";
1856+
case detail::CG::BarrierWaitlist:
1857+
return "barrier waitlist";
18481858
case detail::CG::CopyUSM:
18491859
return "copy usm";
18501860
break;
@@ -1863,6 +1873,8 @@ static std::string cgTypeToString(detail::CG::CGTYPE Type) {
18631873
case detail::CG::Fill2DUSM:
18641874
return "fill 2d usm";
18651875
break;
1876+
case detail::CG::AdviseUSM:
1877+
return "advise usm";
18661878
case detail::CG::Memset2DUSM:
18671879
return "memset 2d usm";
18681880
break;
@@ -1872,6 +1884,16 @@ static std::string cgTypeToString(detail::CG::CGTYPE Type) {
18721884
case detail::CG::CopyFromDeviceGlobal:
18731885
return "copy from device_global";
18741886
break;
1887+
case detail::CG::ReadWriteHostPipe:
1888+
return "read_write host pipe";
1889+
case detail::CG::ExecCommandBuffer:
1890+
return "exec command buffer";
1891+
case detail::CG::CopyImage:
1892+
return "copy image";
1893+
case detail::CG::SemaphoreWait:
1894+
return "semaphore wait";
1895+
case detail::CG::SemaphoreSignal:
1896+
return "semaphore signal";
18751897
default:
18761898
return "unknown";
18771899
break;
@@ -2102,7 +2124,7 @@ void ExecCGCommand::emitInstrumentationData() {
21022124
KernelCG->getKernelName(), MAddress, FromSource);
21032125
} break;
21042126
default:
2105-
KernelName = cgTypeToString(MCommandGroup->getType());
2127+
KernelName = getTypeString();
21062128
break;
21072129
}
21082130

@@ -2150,7 +2172,7 @@ void ExecCGCommand::printDot(std::ostream &Stream) const {
21502172
break;
21512173
}
21522174
default:
2153-
Stream << "CG type: " << cgTypeToString(MCommandGroup->getType()) << "\\n";
2175+
Stream << "CG type: " << getTypeString() << "\\n";
21542176
break;
21552177
}
21562178

@@ -2165,6 +2187,10 @@ void ExecCGCommand::printDot(std::ostream &Stream) const {
21652187
}
21662188
}
21672189

2190+
std::string_view ExecCGCommand::getTypeString() const {
2191+
return cgTypeToString(MCommandGroup->getType());
2192+
}
2193+
21682194
// SYCL has a parallel_for_work_group variant where the only NDRange
21692195
// characteristics set by a user is the number of work groups. This does not
21702196
// map to the OpenCL clEnqueueNDRangeAPI, which requires global work size to

sycl/source/detail/scheduler/commands.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ class Command {
244244

245245
bool isHostTask() const;
246246

247+
bool isFusable() const;
248+
247249
protected:
248250
QueueImplPtr MQueue;
249251
EventImplPtr MEvent;
@@ -648,6 +650,7 @@ class ExecCGCommand : public Command {
648650

649651
void printDot(std::ostream &Stream) const final;
650652
void emitInstrumentationData() final;
653+
std::string_view getTypeString() const;
651654

652655
detail::CG &getCG() const { return *MCommandGroup; }
653656

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 64 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "detail/config.hpp"
1010
#include <detail/context_impl.hpp>
1111
#include <detail/event_impl.hpp>
12+
#include <sstream>
1213
#include <sycl/feature_test.hpp>
1314
#if SYCL_EXT_CODEPLAY_KERNEL_FUSION
1415
#include <detail/jit_compiler.hpp>
@@ -949,66 +950,75 @@ Scheduler::GraphBuildResult Scheduler::GraphBuilder::addCG(
949950
if (!NewCmd)
950951
throw runtime_error("Out of host memory", PI_ERROR_OUT_OF_HOST_MEMORY);
951952

952-
// Host tasks cannot participate in fusion. They take the regular route. If
953-
// they create any requirement or event dependency on any of the kernels in
954-
// the fusion list, this will lead to cancellation of the fusion in the
955-
// GraphProcessor.
953+
// Only device kernel command groups can participate in fusion. Otherwise,
954+
// command groups take the regular route. If they create any requirement or
955+
// event dependency on any of the kernels in the fusion list, this will lead
956+
// to cancellation of the fusion in the GraphProcessor.
956957
auto QUniqueID = std::hash<sycl::detail::queue_impl *>()(Queue.get());
957-
if (isInFusionMode(QUniqueID) && !NewCmd->isHostTask()) {
958-
auto *FusionCmd = findFusionList(QUniqueID)->second.get();
959-
960-
bool dependsOnFusion = false;
961-
for (auto Ev = Events.begin(); Ev != Events.end();) {
962-
auto *EvDepCmd = static_cast<Command *>((*Ev)->getCommand());
963-
if (!EvDepCmd) {
964-
continue;
965-
}
966-
// Handle event dependencies on any commands part of another active
967-
// fusion.
968-
if (EvDepCmd->getQueue() != Queue && isPartOfActiveFusion(EvDepCmd)) {
969-
printFusionWarning("Aborting fusion because of event dependency from a "
970-
"different fusion");
971-
cancelFusion(EvDepCmd->getQueue(), ToEnqueue);
972-
}
973-
// Check if this command depends on the placeholder command for the fusion
974-
// itself participates in.
975-
if (EvDepCmd == FusionCmd) {
976-
Ev = Events.erase(Ev);
977-
dependsOnFusion = true;
978-
} else {
979-
++Ev;
958+
if (isInFusionMode(QUniqueID)) {
959+
if (NewCmd->isFusable()) {
960+
auto *FusionCmd = findFusionList(QUniqueID)->second.get();
961+
962+
bool dependsOnFusion = false;
963+
for (auto Ev = Events.begin(); Ev != Events.end();) {
964+
auto *EvDepCmd = static_cast<Command *>((*Ev)->getCommand());
965+
if (!EvDepCmd) {
966+
continue;
967+
}
968+
// Handle event dependencies on any commands part of another active
969+
// fusion.
970+
if (EvDepCmd->getQueue() != Queue && isPartOfActiveFusion(EvDepCmd)) {
971+
printFusionWarning(
972+
"Aborting fusion because of event dependency from a "
973+
"different fusion");
974+
cancelFusion(EvDepCmd->getQueue(), ToEnqueue);
975+
}
976+
// Check if this command depends on the placeholder command for the
977+
// fusion itself participates in.
978+
if (EvDepCmd == FusionCmd) {
979+
Ev = Events.erase(Ev);
980+
dependsOnFusion = true;
981+
} else {
982+
++Ev;
983+
}
980984
}
981-
}
982985

983-
// If this command has an explicit event dependency on the placeholder
984-
// command for this fusion (because it used depends_on on the event returned
985-
// by submitting another kernel to this fusion earlier), add a dependency on
986-
// all the commands in the fusion list so far.
987-
if (dependsOnFusion) {
988-
for (auto *Cmd : FusionCmd->getFusionList()) {
989-
Events.push_back(Cmd->getEvent());
986+
// If this command has an explicit event dependency on the placeholder
987+
// command for this fusion (because it used depends_on on the event
988+
// returned by submitting another kernel to this fusion earlier), add a
989+
// dependency on all the commands in the fusion list so far.
990+
if (dependsOnFusion) {
991+
for (auto *Cmd : FusionCmd->getFusionList()) {
992+
Events.push_back(Cmd->getEvent());
993+
}
990994
}
991-
}
992995

993-
// Add the kernel to the graph, but delay the enqueue of any auxiliary
994-
// commands (e.g., allocations) resulting from that process by adding them
995-
// to the list of auxiliary commands of the fusion command.
996-
createGraphForCommand(NewCmd.get(), NewCmd->getCG(),
997-
isInteropHostTask(NewCmd.get()), Reqs, Events, Queue,
998-
FusionCmd->auxiliaryCommands());
999-
1000-
// Set the fusion command, so we recognize when another command depends on a
1001-
// kernel in the fusion list.
1002-
FusionCmd->addToFusionList(NewCmd.get());
1003-
NewCmd->MFusionCmd = FusionCmd;
1004-
std::vector<Command *> ToCleanUp;
1005-
// Add an event dependency from the fusion placeholder command to the new
1006-
// kernel.
1007-
auto ConnectionCmd = FusionCmd->addDep(NewCmd->getEvent(), ToCleanUp);
1008-
if (ConnectionCmd) {
1009-
FusionCmd->auxiliaryCommands().push_back(ConnectionCmd);
996+
// Add the kernel to the graph, but delay the enqueue of any auxiliary
997+
// commands (e.g., allocations) resulting from that process by adding them
998+
// to the list of auxiliary commands of the fusion command.
999+
createGraphForCommand(NewCmd.get(), NewCmd->getCG(),
1000+
isInteropHostTask(NewCmd.get()), Reqs, Events,
1001+
Queue, FusionCmd->auxiliaryCommands());
1002+
1003+
// Set the fusion command, so we recognize when another command depends on
1004+
// a kernel in the fusion list.
1005+
FusionCmd->addToFusionList(NewCmd.get());
1006+
NewCmd->MFusionCmd = FusionCmd;
1007+
std::vector<Command *> ToCleanUp;
1008+
// Add an event dependency from the fusion placeholder command to the new
1009+
// kernel.
1010+
auto ConnectionCmd = FusionCmd->addDep(NewCmd->getEvent(), ToCleanUp);
1011+
if (ConnectionCmd) {
1012+
FusionCmd->auxiliaryCommands().push_back(ConnectionCmd);
1013+
}
1014+
return {NewCmd.release(), FusionCmd->getEvent(), false};
1015+
} else {
1016+
std::string s;
1017+
std::stringstream ss(s);
1018+
ss << "Not fusing '" << NewCmd->getTypeString()
1019+
<< "' command group. Can only fuse device kernel command groups.";
1020+
printFusionWarning(ss.str());
10101021
}
1011-
return {NewCmd.release(), FusionCmd->getEvent(), false};
10121022
}
10131023
createGraphForCommand(NewCmd.get(), NewCmd->getCG(),
10141024
isInteropHostTask(NewCmd.get()), Reqs, Events, Queue,
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// RUN: %{build} -fsycl-embed-ir -o %t.out
2+
// RUN: env SYCL_RT_WARNING_LEVEL=2 %{run} %t.out 2>&1 | FileCheck %s
3+
4+
// Test non-kernel device command groups are not fused
5+
6+
#include <sycl/sycl.hpp>
7+
8+
using namespace sycl;
9+
10+
int main() {
11+
constexpr size_t dataSize = 512;
12+
constexpr float Pattern{10};
13+
14+
queue q{ext::codeplay::experimental::property::queue::enable_fusion{}};
15+
ext::codeplay::experimental::fusion_wrapper fw(q);
16+
17+
constexpr size_t count = 64;
18+
auto *dst = malloc_device<float>(count, q);
19+
auto *src = malloc_device<float>(count, q);
20+
21+
{
22+
// CHECK: Not fusing 'copy acc to ptr' command group. Can only fuse device kernel command groups.
23+
buffer<float> src(dataSize);
24+
std::shared_ptr<float> dst(new float[dataSize]);
25+
fw.start_fusion();
26+
q.submit([&](handler &cgh) {
27+
accessor acc(src, cgh, read_only);
28+
cgh.copy(acc, dst);
29+
});
30+
fw.complete_fusion();
31+
}
32+
33+
{
34+
// CHECK: Not fusing 'copy ptr to acc' command group. Can only fuse device kernel command groups.
35+
buffer<float> dst(dataSize);
36+
std::shared_ptr<float> src(new float[dataSize]);
37+
fw.start_fusion();
38+
q.submit([&](handler &cgh) {
39+
accessor acc(dst, cgh, write_only);
40+
cgh.copy(src, acc);
41+
});
42+
fw.complete_fusion();
43+
}
44+
45+
{
46+
// CHECK: Not fusing 'copy acc to acc' command group. Can only fuse device kernel command groups.
47+
buffer<float> dst(dataSize);
48+
buffer<float> src(dataSize);
49+
fw.start_fusion();
50+
q.submit([&](handler &cgh) {
51+
accessor acc0(src, cgh, read_only);
52+
accessor acc1(dst, cgh, write_only);
53+
cgh.copy(acc0, acc1);
54+
});
55+
fw.complete_fusion();
56+
}
57+
58+
{
59+
// CHECK: Not fusing 'barrier' command group. Can only fuse device kernel command groups.
60+
fw.start_fusion();
61+
q.submit([&](handler &cgh) { cgh.ext_oneapi_barrier(); });
62+
fw.complete_fusion();
63+
}
64+
65+
{
66+
// CHECK: Not fusing 'barrier waitlist' command group. Can only fuse device kernel command groups.
67+
buffer<float> dst(dataSize);
68+
buffer<float> src(dataSize);
69+
std::vector<event> event_list;
70+
event_list.push_back(q.submit([&](handler &cgh) {
71+
accessor acc0(src, cgh, read_only);
72+
accessor acc1(dst, cgh, write_only);
73+
cgh.copy(acc0, acc1);
74+
}));
75+
fw.start_fusion();
76+
q.submit([&](handler &cgh) { cgh.ext_oneapi_barrier(event_list); });
77+
fw.complete_fusion();
78+
}
79+
80+
{
81+
// CHECK: Not fusing 'fill' command group. Can only fuse device kernel command groups.
82+
buffer<float> dst(dataSize);
83+
fw.start_fusion();
84+
q.submit([&](handler &cgh) {
85+
accessor acc(dst, cgh, write_only);
86+
cgh.fill(acc, Pattern);
87+
});
88+
fw.complete_fusion();
89+
}
90+
91+
{
92+
// CHECK: Not fusing 'copy usm' command group. Can only fuse device kernel command groups.
93+
fw.start_fusion();
94+
q.submit([&](handler &cgh) { cgh.memcpy(dst, src, count); });
95+
fw.complete_fusion();
96+
}
97+
98+
{
99+
// CHECK: Not fusing 'fill usm' command group. Can only fuse device kernel command groups.
100+
fw.start_fusion();
101+
q.submit([&](handler &cgh) {
102+
cgh.memset(dst, static_cast<int>(Pattern), count);
103+
});
104+
fw.complete_fusion();
105+
}
106+
107+
free(src, q);
108+
free(dst, q);
109+
}

0 commit comments

Comments
 (0)