Skip to content

Commit 120b4b5

Browse files
committed
[NFC][SYCL] Refactor kernel wrapper generation.
Signed-off-by: Vladimir Lazarev <[email protected]>
1 parent f1d20ee commit 120b4b5

File tree

5 files changed

+77
-63
lines changed

5 files changed

+77
-63
lines changed

clang/include/clang/Sema/Sema.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10852,7 +10852,7 @@ class Sema {
1085210852
void AddSyclKernel(Decl * d) { SyclKernel.push_back(d); }
1085310853
SmallVector<Decl*, 4> &SyclKernels() { return SyclKernel; }
1085410854

10855-
void ConstructSYCLKernel(CXXMemberCallExpr* e);
10855+
void ConstructSYCLKernel(FunctionDecl* KernelHelper);
1085610856
};
1085710857

1085810858
/// RAII object that enters a new expression evaluation context.

clang/lib/Sema/SemaOverload.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13012,15 +13012,6 @@ Sema::BuildCallToMemberFunction(Scope *S, Expr *MemExprE,
1301213012
CXXMemberCallExpr::Create(Context, MemExprE, Args, ResultType, VK,
1301313013
RParenLoc, Proto->getNumParams());
1301413014

13015-
if (getLangOpts().SYCL) {
13016-
auto Func = TheCall->getMethodDecl();
13017-
auto Name = Func->getQualifiedNameAsString();
13018-
if (Name == "cl::sycl::handler::parallel_for" ||
13019-
Name == "cl::sycl::handler::single_task") {
13020-
ConstructSYCLKernel(TheCall);
13021-
}
13022-
}
13023-
1302413015
// Check for a valid return type.
1302513016
if (CheckCallReturnType(Method->getReturnType(), MemExpr->getMemberLoc(),
1302613017
TheCall, Method))

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,44 @@
1212
#include "clang/AST/AST.h"
1313
#include "clang/Sema/Sema.h"
1414
#include "llvm/ADT/SmallVector.h"
15+
#include "TreeTransform.h"
1516

1617
using namespace clang;
1718

18-
LambdaExpr *getBodyAsLambda(CXXMemberCallExpr *e) {
19-
auto LastArg = e->getArg(e->getNumArgs() - 1);
20-
return dyn_cast<LambdaExpr>(LastArg);
19+
typedef llvm::DenseMap<DeclaratorDecl *, DeclaratorDecl *> DeclMap;
20+
21+
class KernelBodyTransform : public TreeTransform<KernelBodyTransform> {
22+
public:
23+
KernelBodyTransform(llvm::DenseMap<DeclaratorDecl *, DeclaratorDecl *> &Map,
24+
Sema &S)
25+
: TreeTransform<KernelBodyTransform>(S), DMap(Map), SemaRef(S) {}
26+
bool AlwaysRebuild() { return true; }
27+
28+
ExprResult TransformDeclRefExpr(DeclRefExpr *DRE) {
29+
auto Ref = dyn_cast<DeclaratorDecl>(DRE->getDecl());
30+
if (Ref) {
31+
auto NewDecl = DMap[Ref];
32+
if (NewDecl) {
33+
return DeclRefExpr::Create(
34+
SemaRef.getASTContext(), DRE->getQualifierLoc(),
35+
DRE->getTemplateKeywordLoc(), NewDecl, false, DRE->getNameInfo(),
36+
NewDecl->getType(), DRE->getValueKind());
37+
}
38+
}
39+
return DRE;
40+
}
41+
42+
private:
43+
DeclMap DMap;
44+
Sema &SemaRef;
45+
};
46+
47+
CXXRecordDecl* getBodyAsLambda(FunctionDecl *FD) {
48+
auto FirstArg = (*FD->param_begin());
49+
if (FirstArg)
50+
if (FirstArg->getType()->getAsCXXRecordDecl()->isLambda())
51+
return FirstArg->getType()->getAsCXXRecordDecl();
52+
return nullptr;
2153
}
2254

2355
FunctionDecl *CreateSYCLKernelFunction(ASTContext &Context, StringRef Name,
@@ -54,17 +86,16 @@ FunctionDecl *CreateSYCLKernelFunction(ASTContext &Context, StringRef Name,
5486
return Result;
5587
}
5688

57-
CompoundStmt *CreateSYCLKernelBody(Sema &S, CXXMemberCallExpr *e,
89+
CompoundStmt *CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelHelper,
5890
DeclContext *DC) {
5991

6092
llvm::SmallVector<Stmt *, 16> BodyStmts;
6193

6294
// TODO: case when kernel is functor
6395
// TODO: possible refactoring when functor case will be completed
64-
LambdaExpr *LE = getBodyAsLambda(e);
65-
if (LE) {
96+
CXXRecordDecl *LC = getBodyAsLambda(KernelHelper);
97+
if (LC) {
6698
// Create Lambda object
67-
CXXRecordDecl *LC = LE->getLambdaClass();
6899
auto LambdaVD = VarDecl::Create(
69100
S.Context, DC, SourceLocation(), SourceLocation(), LC->getIdentifier(),
70101
QualType(LC->getTypeForDecl(), 0), LC->getLambdaTypeInfo(), SC_None);
@@ -137,43 +168,23 @@ CompoundStmt *CreateSYCLKernelBody(Sema &S, CXXMemberCallExpr *e,
137168
TargetFuncParam++;
138169
}
139170

140-
// Create Lambda operator () call
141-
FunctionDecl *LO = LE->getCallOperator();
142-
ArrayRef<ParmVarDecl *> Args = LO->parameters();
143-
llvm::SmallVector<Expr *, 16> ParamStmts(1);
144-
ParamStmts[0] = dyn_cast<Expr>(LambdaDRE);
145-
146-
// Collect arguments for () operator
147-
for (auto Arg : Args) {
148-
QualType ArgType = Arg->getOriginalType();
149-
// Declare variable for parameter and pass it to call
150-
auto param_VD =
151-
VarDecl::Create(S.Context, DC, SourceLocation(), SourceLocation(),
152-
Arg->getIdentifier(), ArgType,
153-
S.Context.getTrivialTypeSourceInfo(ArgType), SC_None);
154-
Stmt *param_DS = new (S.Context)
155-
DeclStmt(DeclGroupRef(param_VD), SourceLocation(), SourceLocation());
156-
BodyStmts.push_back(param_DS);
157-
auto DRE = DeclRefExpr::Create(S.Context, NestedNameSpecifierLoc(),
158-
SourceLocation(), param_VD, false,
159-
DeclarationNameInfo(), ArgType, VK_LValue);
160-
Expr *Res = ImplicitCastExpr::Create(
161-
S.Context, ArgType, CK_LValueToRValue, DRE, nullptr, VK_RValue);
162-
ParamStmts.push_back(Res);
163-
}
171+
// In function from headers lambda is function parameter, we need
172+
// to replace all refs to this lambda with our vardecl.
173+
// I used TreeTransform here, but I'm not sure that it is good solution
174+
// Also I used map and I'm not sure about it too.
175+
Stmt* FunctionBody = KernelHelper->getBody();
176+
DeclMap DMap;
177+
ParmVarDecl* LambdaParam = *(KernelHelper->param_begin());
178+
// DeclRefExpr with valid source location but with decl which is not marked
179+
// as used is invalid.
180+
LambdaVD->setIsUsed();
181+
DMap[LambdaParam] = LambdaVD;
182+
// Without PushFunctionScope I had segfault. Maybe we also need to do pop.
183+
S.PushFunctionScope();
184+
KernelBodyTransform KBT(DMap, S);
185+
Stmt* NewBody = KBT.TransformStmt(FunctionBody).get();
186+
BodyStmts.push_back(NewBody);
164187

165-
// Create ref for call operator
166-
DeclRefExpr *DRE = new (S.Context)
167-
DeclRefExpr(S.Context, LO, false, LO->getType(), VK_LValue,
168-
SourceLocation());
169-
QualType ResultTy = LO->getReturnType();
170-
ExprValueKind VK = Expr::getValueKindForType(ResultTy);
171-
ResultTy = ResultTy.getNonLValueExprType(S.Context);
172-
173-
CXXOperatorCallExpr *TheCall = CXXOperatorCallExpr::Create(
174-
S.Context, OO_Call, DRE, ParamStmts, ResultTy, VK, SourceLocation(),
175-
FPOptions(), clang::CallExpr::ADLCallKind::NotADL );
176-
BodyStmts.push_back(TheCall);
177188
}
178189
return CompoundStmt::Create(S.Context, BodyStmts, SourceLocation(),
179190
SourceLocation());
@@ -222,9 +233,9 @@ void BuildArgTys(ASTContext &Context,
222233
}
223234
}
224235

225-
void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) {
236+
void Sema::ConstructSYCLKernel(FunctionDecl *KernelHelper) {
226237
// TODO: Case when kernel is functor
227-
LambdaExpr *LE = getBodyAsLambda(e);
238+
CXXRecordDecl *LE = getBodyAsLambda(KernelHelper);
228239
if (LE) {
229240

230241
llvm::SmallVector<DeclaratorDecl *, 16> ArgDecls;
@@ -238,9 +249,8 @@ void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) {
238249
BuildArgTys(getASTContext(), ArgDecls, NewArgDecls, ArgTys);
239250

240251
// Get Name for our kernel.
241-
FunctionDecl *FuncDecl = e->getMethodDecl();
242252
const TemplateArgumentList *TemplateArgs =
243-
FuncDecl->getTemplateSpecializationArgs();
253+
KernelHelper->getTemplateSpecializationArgs();
244254
QualType KernelNameType = TemplateArgs->get(0).getAsType();
245255
std::string Name = KernelNameType.getBaseTypeIdentifier()->getName().str();
246256

@@ -256,7 +266,7 @@ void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) {
256266
FunctionDecl *SYCLKernel =
257267
CreateSYCLKernelFunction(getASTContext(), Name, ArgTys, NewArgDecls);
258268

259-
CompoundStmt *SYCLKernelBody = CreateSYCLKernelBody(*this, e, SYCLKernel);
269+
CompoundStmt *SYCLKernelBody = CreateSYCLKernelBody(*this, KernelHelper, SYCLKernel);
260270
SYCLKernel->setBody(SYCLKernelBody);
261271

262272
AddSyclKernel(SYCLKernel);

clang/lib/Sema/SemaTemplateInstantiateDecl.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5231,14 +5231,28 @@ void Sema::PerformPendingInstantiations(bool LocalOnly) {
52315231
Function, [this, Inst, DefinitionRequired](FunctionDecl *CurFD) {
52325232
InstantiateFunctionDefinition(/*FIXME:*/ Inst.second, CurFD, true,
52335233
DefinitionRequired, true);
5234-
if (CurFD->isDefined())
5234+
if (CurFD->isDefined()) {
5235+
// Because all SYCL kernel functions are template functions - they
5236+
// have deferred instantination. We need bodies of these functions
5237+
// so we are checking for SYCL kernel attribute after instantination.
5238+
if (getLangOpts().SYCL && CurFD->hasAttr<SYCLKernelAttr>()) {
5239+
ConstructSYCLKernel(CurFD);
5240+
}
52355241
CurFD->setInstantiationIsPending(false);
5242+
}
52365243
});
52375244
} else {
52385245
InstantiateFunctionDefinition(/*FIXME:*/ Inst.second, Function, true,
52395246
DefinitionRequired, true);
5240-
if (Function->isDefined())
5247+
if (Function->isDefined()) {
5248+
// Because all SYCL kernel functions are template functions - they
5249+
// have deferred instantination. We need bodies of these functions
5250+
// so we are checking for SYCL kernel attribute after instantination.
5251+
if (getLangOpts().SYCL && Function->hasAttr<SYCLKernelAttr>()) {
5252+
ConstructSYCLKernel(Function);
5253+
}
52415254
Function->setInstantiationIsPending(false);
5255+
}
52425256
}
52435257
continue;
52445258
}

clang/test/CodeGenSYCL/kernel-with-id.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@ int main() {
1616

1717
deviceQueue.submit([&](cl::sycl::handler &cgh) {
1818
auto accessorA = bufferA.template get_access<cl::sycl::access::mode::read_write>(cgh);
19-
// CHECK: %wiID = alloca %"struct.cl::sycl::id", align 8
2019
// CHECK: call spir_func void @_ZN2cl4sycl8accessorIiLi1ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0EE13__set_pointerEPU3AS1i(%"class.cl::sycl::accessor"* %1, i32 addrspace(1)* %2)
21-
// CHECK: call spir_func void @"_ZZZ4mainENK3$_0clERN2cl4sycl7handlerEENKUlNS1_2idILm1EEEE_clES5_"(%class.anon* %0, %"struct.cl::sycl::id"* byval align 8 %wiID)
22-
// CHECK: %call = call spir_func i64 @_Z13get_global_idj(i32 0)
20+
// CHECK: %call = call spir_func i64 @_Z13get_global_idj(i32 %{{.*}})
21+
// CHECK: call spir_func void @"_ZZZ4mainENK3$_0clERN2cl4sycl7handlerEENKUlNS1_2idILm1EEEE_clES5_"(%class.anon* %0, %"struct.cl::sycl::id"* byval align 8 %{{.*}})
2322
cgh.parallel_for<class kernel_function>(numOfItems,
2423
[=](cl::sycl::id<1> wiID) {
2524
accessorA[wiID] = accessorA[wiID] * accessorA[wiID];

0 commit comments

Comments
 (0)