Skip to content

Commit cb2efef

Browse files
authored
[SYCLLowerIR][SemaSYCL] Support indirect hierarchical parallelism (#14264)
This PR adds a missing feature in SYCL hierarchical parallelism support. Specifically, this PR adds support for the case when there are functions between parallel_for_work_group and parallel_for_work_item in the call stack. For example: void foo(sycl::group<1> group, ...) { group.parallel_for_work_item(range<1>(), [&](h_item<1> i) { ... }); } // ... cgh.parallel_for_work_group<class kernel>( range<1>(...), range<1>(...), [=](group<1> g) { foo(g, ...); }); --------- Signed-off-by: Sudarsanam, Arvind <[email protected]>
1 parent 2b48ab5 commit cb2efef

File tree

4 files changed

+118
-42
lines changed

4 files changed

+118
-42
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,30 @@ static bool isDeclaredInSYCLNamespace(const Decl *D) {
752752
return ND && ND->getName() == "sycl";
753753
}
754754

755+
static bool isSYCLPrivateMemoryVar(VarDecl *VD) {
756+
return SemaSYCL::isSyclType(VD->getType(), SYCLTypeAttr::private_memory);
757+
}
758+
759+
static void addScopeAttrToLocalVars(FunctionDecl &F) {
760+
for (Decl *D : F.decls()) {
761+
VarDecl *VD = dyn_cast<VarDecl>(D);
762+
763+
if (!VD || isa<ParmVarDecl>(VD) ||
764+
VD->getStorageDuration() != StorageDuration::SD_Automatic)
765+
continue;
766+
// Local variables of private_memory type in the WG scope still have WI
767+
// scope, all the rest - WG scope. Simple logic
768+
// "if no scope than it is WG scope" won't work, because compiler may add
769+
// locals not declared in user code (lambda object parameter, byval
770+
// arguments) which will result in alloca w/o any attribute, so need WI
771+
// scope too.
772+
SYCLScopeAttr::Level L = isSYCLPrivateMemoryVar(VD)
773+
? SYCLScopeAttr::Level::WorkItem
774+
: SYCLScopeAttr::Level::WorkGroup;
775+
VD->addAttr(SYCLScopeAttr::CreateImplicit(F.getASTContext(), L));
776+
}
777+
}
778+
755779
// This type does the heavy lifting for the management of device functions,
756780
// recursive function detection, and attribute collection for a single
757781
// kernel/external function. It walks the callgraph to find all functions that
@@ -801,12 +825,24 @@ class SingleDeviceFunctionTracker {
801825
// Note: Here, we assume that this is called from within a
802826
// parallel_for_work_group; it is undefined to call it otherwise.
803827
// We deliberately do not diagnose a violation.
828+
// The following changes have also been added:
829+
// 1. The function inside which the parallel_for_work_item exists is
830+
// marked with WorkGroup scope attribute, if not present already.
831+
// 2. The local variables inside the function are marked with appropriate
832+
// scope.
804833
if (CurrentDecl->getIdentifier() &&
805834
CurrentDecl->getIdentifier()->getName() == "parallel_for_work_item" &&
806835
isDeclaredInSYCLNamespace(CurrentDecl) &&
807836
!CurrentDecl->hasAttr<SYCLScopeAttr>()) {
808837
CurrentDecl->addAttr(SYCLScopeAttr::CreateImplicit(
809838
Parent.SemaSYCLRef.getASTContext(), SYCLScopeAttr::Level::WorkItem));
839+
FunctionDecl *Caller = CallStack.back();
840+
if (!Caller->hasAttr<SYCLScopeAttr>()) {
841+
Caller->addAttr(
842+
SYCLScopeAttr::CreateImplicit(Parent.SemaSYCLRef.getASTContext(),
843+
SYCLScopeAttr::Level::WorkGroup));
844+
addScopeAttrToLocalVars(*Caller);
845+
}
810846
}
811847

812848
// We previously thought we could skip this function if we'd seen it before,
@@ -999,30 +1035,6 @@ class MarkWIScopeFnVisitor : public RecursiveASTVisitor<MarkWIScopeFnVisitor> {
9991035
ASTContext &Ctx;
10001036
};
10011037

1002-
static bool isSYCLPrivateMemoryVar(VarDecl *VD) {
1003-
return SemaSYCL::isSyclType(VD->getType(), SYCLTypeAttr::private_memory);
1004-
}
1005-
1006-
static void addScopeAttrToLocalVars(CXXMethodDecl &F) {
1007-
for (Decl *D : F.decls()) {
1008-
VarDecl *VD = dyn_cast<VarDecl>(D);
1009-
1010-
if (!VD || isa<ParmVarDecl>(VD) ||
1011-
VD->getStorageDuration() != StorageDuration::SD_Automatic)
1012-
continue;
1013-
// Local variables of private_memory type in the WG scope still have WI
1014-
// scope, all the rest - WG scope. Simple logic
1015-
// "if no scope than it is WG scope" won't work, because compiler may add
1016-
// locals not declared in user code (lambda object parameter, byval
1017-
// arguments) which will result in alloca w/o any attribute, so need WI
1018-
// scope too.
1019-
SYCLScopeAttr::Level L = isSYCLPrivateMemoryVar(VD)
1020-
? SYCLScopeAttr::Level::WorkItem
1021-
: SYCLScopeAttr::Level::WorkGroup;
1022-
VD->addAttr(SYCLScopeAttr::CreateImplicit(F.getASTContext(), L));
1023-
}
1024-
}
1025-
10261038
/// Return method by name
10271039
static CXXMethodDecl *getMethodByName(const CXXRecordDecl *CRD,
10281040
StringRef MethodName) {

clang/test/CodeGenSYCL/sycl-pf-work-item.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: %clang_cc1 -fsycl-is-device -triple spir64-unknown-unknown -internal-isystem %S/Inputs -emit-llvm %s -o - | FileCheck %s
22
// This test checks if the parallel_for_work_item called indirecly from
33
// parallel_for_work_group gets the work_item_scope marker on it.
4+
// It also checks if the calling function gets the work_group_scope marker on it.
45
#include <sycl.hpp>
56

67
void foo(sycl::group<1> work_group) {
@@ -18,4 +19,5 @@ int main(int argc, char **argv) {
1819
return 0;
1920
}
2021

22+
// CHECK: define {{.*}} void {{.*}}foo{{.*}} !work_group_scope
2123
// CHECK: define {{.*}} void @{{.*}}sycl{{.*}}group{{.*}}parallel_for_work_item{{.*}}(ptr addrspace(4) noundef align 1 dereferenceable_or_null(1) %this) {{.*}}!work_item_scope {{.*}}!parallel_for_work_item

llvm/lib/SYCLLowerIR/LowerWGScope.cpp

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,19 +65,6 @@
6565
// (1) - materialization of a PFWI object
6666
// (2) - "fixup" of the private variable address.
6767
//
68-
// TODO: add support for the case when there are other functions between
69-
// parallel_for_work_group and parallel_for_work_item in the call stack.
70-
// For example:
71-
//
72-
// void foo(sycl::group<1> group, ...) {
73-
// group.parallel_for_work_item(range<1>(), [&](h_item<1> i) { ... });
74-
// }
75-
// ...
76-
// cgh.parallel_for_work_group<class kernel>(
77-
// range<1>(...), range<1>(...), [=](group<1> g) {
78-
// foo(g, ...);
79-
// });
80-
//
8168
// TODO The approach employed by this pass generates lots of barriers and data
8269
// copying between private and local memory, which might not be efficient. There
8370
// are optimization opportunities listed below. Also other approaches can be
@@ -209,11 +196,36 @@ static bool isCallToAFuncMarkedWithMD(const Instruction *I, const char *MD) {
209196
return F && F->getMetadata(MD);
210197
}
211198

212-
// Checks is this is a call to parallel_for_work_item.
199+
// Recursively searches for a call to a function with work_group
200+
// metadata inside F.
201+
static bool hasCallToAFuncWithWGMetadata(Function &F) {
202+
for (auto &BB : F)
203+
for (auto &I : BB) {
204+
if (isCallToAFuncMarkedWithMD(&I, WG_SCOPE_MD))
205+
return true;
206+
const CallInst *Call = dyn_cast<CallInst>(&I);
207+
Function *F = dyn_cast_or_null<Function>(Call ? Call->getCalledFunction()
208+
: nullptr);
209+
if (F && hasCallToAFuncWithWGMetadata(*F))
210+
return true;
211+
}
212+
return false;
213+
}
214+
215+
// Checks if this is a call to parallel_for_work_item.
213216
static bool isPFWICall(const Instruction *I) {
214217
return isCallToAFuncMarkedWithMD(I, PFWI_MD);
215218
}
216219

220+
// Checks if F has any calls to function marked with PFWI_MD metadata.
221+
static bool hasPFWICall(Function &F) {
222+
for (auto &BB : F)
223+
for (auto &I : BB)
224+
if (isPFWICall(&I))
225+
return true;
226+
return false;
227+
}
228+
217229
// Checks if given instruction must be executed by all work items.
218230
static bool isWIScopeInst(const Instruction *I) {
219231
if (I->isTerminator())
@@ -425,6 +437,17 @@ static void copyBetweenPrivateAndShadow(Value *L, GlobalVariable *Shadow,
425437
}
426438
}
427439

440+
// Skip allocas, addrspacecasts associated with allocas and debug insts.
441+
static Instruction *getFirstInstToProcess(BasicBlock *BB) {
442+
Instruction *I = &BB->front();
443+
for (;
444+
I->getOpcode() == Instruction::Alloca ||
445+
I->getOpcode() == Instruction::AddrSpaceCast || I->isDebugOrPseudoInst();
446+
I = I->getNextNode()) {
447+
}
448+
return I;
449+
}
450+
428451
// Performs the following transformation for each basic block in the input map:
429452
//
430453
// BB:
@@ -462,7 +485,11 @@ static void materializeLocalsInWIScopeBlocksImpl(
462485
for (auto &P : BB2MatLocals) {
463486
// generate LeaderBB and private<->shadow copies in proper BBs
464487
BasicBlock *LeaderBB = P.first;
465-
BasicBlock *BB = LeaderBB->splitBasicBlock(&LeaderBB->front(), "LeaderMat");
488+
// Skip allocas, addrspacecasts associated with allocas and debug insts.
489+
// Alloca instructions and it's associated instructions must be in the
490+
// beginning of the function.
491+
Instruction *LeaderBBFront = getFirstInstToProcess(LeaderBB);
492+
BasicBlock *BB = LeaderBB->splitBasicBlock(LeaderBBFront, "LeaderMat");
466493
// Add a barrier to the original block:
467494
Instruction *At =
468495
spirv::genWGBarrier(*BB->getFirstNonPHI(), TT)->getNextNode();
@@ -476,7 +503,8 @@ static void materializeLocalsInWIScopeBlocksImpl(
476503
// fill the leader BB:
477504
// fetch data from leader's private copy (which is always up to date) into
478505
// the corresponding shadow variable
479-
Builder.SetInsertPoint(&LeaderBB->front());
506+
LeaderBBFront = getFirstInstToProcess(LeaderBB);
507+
Builder.SetInsertPoint(LeaderBBFront);
480508
copyBetweenPrivateAndShadow(L, Shadow, Builder, true /*private->shadow*/);
481509
// store data to the local variable - effectively "refresh" the value of
482510
// the local in each work item in the work group
@@ -485,8 +513,8 @@ static void materializeLocalsInWIScopeBlocksImpl(
485513
false /*shadow->private*/);
486514
}
487515
// now generate the TestBB and the leader WI guard
488-
BasicBlock *TestBB =
489-
LeaderBB->splitBasicBlock(&LeaderBB->front(), "TestMat");
516+
LeaderBBFront = getFirstInstToProcess(LeaderBB);
517+
BasicBlock *TestBB = LeaderBB->splitBasicBlock(LeaderBBFront, "TestMat");
490518
std::swap(TestBB, LeaderBB);
491519
guardBlockWithIsLeaderCheck(TestBB, LeaderBB, BB, At->getDebugLoc(), TT);
492520
}
@@ -752,6 +780,10 @@ PreservedAnalyses SYCLLowerWGScopePass::run(Function &F,
752780
FunctionAnalysisManager &FAM) {
753781
if (!F.getMetadata(WG_SCOPE_MD))
754782
return PreservedAnalyses::all();
783+
// If a function does not have any PFWI calls and it has calls to a function
784+
// that has work_group metadata, then we do not need to lower such functions.
785+
if (!hasPFWICall(F) && hasCallToAFuncWithWGMetadata(F))
786+
return PreservedAnalyses::all();
755787
LLVM_DEBUG(llvm::dbgs() << "Function name: " << F.getName() << "\n");
756788
const auto &TT = llvm::Triple(F.getParent()->getTargetTriple());
757789
// Ranges of "side effect" instructions
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//==- hier_par_indirect.cpp --- hierarchical parallelism test for WG scope--==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// RUN: %{build} -o %t.out
10+
// RUN: %{run} %t.out
11+
12+
// This test checks correctness of hierarchical kernel execution when the work
13+
// item code is not directly inside work group scope.
14+
15+
#include <iostream>
16+
#include <sycl/detail/core.hpp>
17+
18+
void __attribute__((noinline)) foo(sycl::group<1> work_group) {
19+
work_group.parallel_for_work_item([&](sycl::h_item<1> index) {});
20+
}
21+
22+
int main(int argc, char **argv) {
23+
sycl::queue q;
24+
q.submit([&](sycl::handler &cgh) {
25+
cgh.parallel_for_work_group(sycl::range<1>{1}, sycl::range<1>{128},
26+
([=](sycl::group<1> wGroup) { foo(wGroup); }));
27+
}).wait();
28+
std::cout << "test passed" << std::endl;
29+
return 0;
30+
}

0 commit comments

Comments
 (0)