12
12
#include " clang/AST/AST.h"
13
13
#include " clang/Sema/Sema.h"
14
14
#include " llvm/ADT/SmallVector.h"
15
+ #include " TreeTransform.h"
15
16
16
17
using namespace clang ;
17
18
18
- LambdaExpr *getBodyAsLambda (CXXMemberCallExpr *e) {
19
- auto LastArg = e->getArg (e->getNumArgs () - 1 );
20
- return dyn_cast<LambdaExpr>(LastArg);
19
+ typedef llvm::DenseMap<DeclaratorDecl *, DeclaratorDecl *> DeclMap;
20
+
21
+ class KernelBodyTransform : public TreeTransform <KernelBodyTransform> {
22
+ public:
23
+ KernelBodyTransform (llvm::DenseMap<DeclaratorDecl *, DeclaratorDecl *> &Map,
24
+ Sema &S)
25
+ : TreeTransform<KernelBodyTransform>(S), DMap(Map), SemaRef(S) {}
26
+ bool AlwaysRebuild () { return true ; }
27
+
28
+ ExprResult TransformDeclRefExpr (DeclRefExpr *DRE) {
29
+ auto Ref = dyn_cast<DeclaratorDecl>(DRE->getDecl ());
30
+ if (Ref) {
31
+ auto NewDecl = DMap[Ref];
32
+ if (NewDecl) {
33
+ return DeclRefExpr::Create (
34
+ SemaRef.getASTContext (), DRE->getQualifierLoc (),
35
+ DRE->getTemplateKeywordLoc (), NewDecl, false , DRE->getNameInfo (),
36
+ NewDecl->getType (), DRE->getValueKind ());
37
+ }
38
+ }
39
+ return DRE;
40
+ }
41
+
42
+ private:
43
+ DeclMap DMap;
44
+ Sema &SemaRef;
45
+ };
46
+
47
+ CXXRecordDecl* getBodyAsLambda (FunctionDecl *FD) {
48
+ auto FirstArg = (*FD->param_begin ());
49
+ if (FirstArg)
50
+ if (FirstArg->getType ()->getAsCXXRecordDecl ()->isLambda ())
51
+ return FirstArg->getType ()->getAsCXXRecordDecl ();
52
+ return nullptr ;
21
53
}
22
54
23
55
FunctionDecl *CreateSYCLKernelFunction (ASTContext &Context, StringRef Name,
@@ -54,17 +86,16 @@ FunctionDecl *CreateSYCLKernelFunction(ASTContext &Context, StringRef Name,
54
86
return Result;
55
87
}
56
88
57
- CompoundStmt *CreateSYCLKernelBody (Sema &S, CXXMemberCallExpr *e ,
89
+ CompoundStmt *CreateSYCLKernelBody (Sema &S, FunctionDecl *KernelHelper ,
58
90
DeclContext *DC) {
59
91
60
92
llvm::SmallVector<Stmt *, 16 > BodyStmts;
61
93
62
94
// TODO: case when kernel is functor
63
95
// TODO: possible refactoring when functor case will be completed
64
- LambdaExpr *LE = getBodyAsLambda (e );
65
- if (LE ) {
96
+ CXXRecordDecl *LC = getBodyAsLambda (KernelHelper );
97
+ if (LC ) {
66
98
// Create Lambda object
67
- CXXRecordDecl *LC = LE->getLambdaClass ();
68
99
auto LambdaVD = VarDecl::Create (
69
100
S.Context , DC, SourceLocation (), SourceLocation (), LC->getIdentifier (),
70
101
QualType (LC->getTypeForDecl (), 0 ), LC->getLambdaTypeInfo (), SC_None);
@@ -137,43 +168,23 @@ CompoundStmt *CreateSYCLKernelBody(Sema &S, CXXMemberCallExpr *e,
137
168
TargetFuncParam++;
138
169
}
139
170
140
- // Create Lambda operator () call
141
- FunctionDecl *LO = LE->getCallOperator ();
142
- ArrayRef<ParmVarDecl *> Args = LO->parameters ();
143
- llvm::SmallVector<Expr *, 16 > ParamStmts (1 );
144
- ParamStmts[0 ] = dyn_cast<Expr>(LambdaDRE);
145
-
146
- // Collect arguments for () operator
147
- for (auto Arg : Args) {
148
- QualType ArgType = Arg->getOriginalType ();
149
- // Declare variable for parameter and pass it to call
150
- auto param_VD =
151
- VarDecl::Create (S.Context , DC, SourceLocation (), SourceLocation (),
152
- Arg->getIdentifier (), ArgType,
153
- S.Context .getTrivialTypeSourceInfo (ArgType), SC_None);
154
- Stmt *param_DS = new (S.Context )
155
- DeclStmt (DeclGroupRef (param_VD), SourceLocation (), SourceLocation ());
156
- BodyStmts.push_back (param_DS);
157
- auto DRE = DeclRefExpr::Create (S.Context , NestedNameSpecifierLoc (),
158
- SourceLocation (), param_VD, false ,
159
- DeclarationNameInfo (), ArgType, VK_LValue);
160
- Expr *Res = ImplicitCastExpr::Create (
161
- S.Context , ArgType, CK_LValueToRValue, DRE, nullptr , VK_RValue);
162
- ParamStmts.push_back (Res);
163
- }
171
+ // In function from headers lambda is function parameter, we need
172
+ // to replace all refs to this lambda with our vardecl.
173
+ // I used TreeTransform here, but I'm not sure that it is good solution
174
+ // Also I used map and I'm not sure about it too.
175
+ Stmt* FunctionBody = KernelHelper->getBody ();
176
+ DeclMap DMap;
177
+ ParmVarDecl* LambdaParam = *(KernelHelper->param_begin ());
178
+ // DeclRefExpr with valid source location but with decl which is not marked
179
+ // as used is invalid.
180
+ LambdaVD->setIsUsed ();
181
+ DMap[LambdaParam] = LambdaVD;
182
+ // Without PushFunctionScope I had segfault. Maybe we also need to do pop.
183
+ S.PushFunctionScope ();
184
+ KernelBodyTransform KBT (DMap, S);
185
+ Stmt* NewBody = KBT.TransformStmt (FunctionBody).get ();
186
+ BodyStmts.push_back (NewBody);
164
187
165
- // Create ref for call operator
166
- DeclRefExpr *DRE = new (S.Context )
167
- DeclRefExpr (S.Context , LO, false , LO->getType (), VK_LValue,
168
- SourceLocation ());
169
- QualType ResultTy = LO->getReturnType ();
170
- ExprValueKind VK = Expr::getValueKindForType (ResultTy);
171
- ResultTy = ResultTy.getNonLValueExprType (S.Context );
172
-
173
- CXXOperatorCallExpr *TheCall = CXXOperatorCallExpr::Create (
174
- S.Context , OO_Call, DRE, ParamStmts, ResultTy, VK, SourceLocation (),
175
- FPOptions (), clang::CallExpr::ADLCallKind::NotADL );
176
- BodyStmts.push_back (TheCall);
177
188
}
178
189
return CompoundStmt::Create (S.Context , BodyStmts, SourceLocation (),
179
190
SourceLocation ());
@@ -222,9 +233,9 @@ void BuildArgTys(ASTContext &Context,
222
233
}
223
234
}
224
235
225
- void Sema::ConstructSYCLKernel (CXXMemberCallExpr *e ) {
236
+ void Sema::ConstructSYCLKernel (FunctionDecl *KernelHelper ) {
226
237
// TODO: Case when kernel is functor
227
- LambdaExpr *LE = getBodyAsLambda (e );
238
+ CXXRecordDecl *LE = getBodyAsLambda (KernelHelper );
228
239
if (LE) {
229
240
230
241
llvm::SmallVector<DeclaratorDecl *, 16 > ArgDecls;
@@ -238,9 +249,8 @@ void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) {
238
249
BuildArgTys (getASTContext (), ArgDecls, NewArgDecls, ArgTys);
239
250
240
251
// Get Name for our kernel.
241
- FunctionDecl *FuncDecl = e->getMethodDecl ();
242
252
const TemplateArgumentList *TemplateArgs =
243
- FuncDecl ->getTemplateSpecializationArgs ();
253
+ KernelHelper ->getTemplateSpecializationArgs ();
244
254
QualType KernelNameType = TemplateArgs->get (0 ).getAsType ();
245
255
std::string Name = KernelNameType.getBaseTypeIdentifier ()->getName ().str ();
246
256
@@ -256,7 +266,7 @@ void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) {
256
266
FunctionDecl *SYCLKernel =
257
267
CreateSYCLKernelFunction (getASTContext (), Name, ArgTys, NewArgDecls);
258
268
259
- CompoundStmt *SYCLKernelBody = CreateSYCLKernelBody (*this , e , SYCLKernel);
269
+ CompoundStmt *SYCLKernelBody = CreateSYCLKernelBody (*this , KernelHelper , SYCLKernel);
260
270
SYCLKernel->setBody (SYCLKernelBody);
261
271
262
272
AddSyclKernel (SYCLKernel);
0 commit comments