14
14
#include " clang/AST/RecordLayout.h"
15
15
#include " clang/AST/RecursiveASTVisitor.h"
16
16
#include " clang/Sema/Sema.h"
17
- #include " llvm/ADT/SmallVector.h"
18
17
#include " llvm/ADT/SmallPtrSet.h"
18
+ #include " llvm/ADT/SmallVector.h"
19
19
#include " llvm/Support/FileSystem.h"
20
20
#include " llvm/Support/Path.h"
21
21
#include " llvm/Support/raw_ostream.h"
@@ -154,14 +154,28 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
154
154
S.Context , NestedNameSpecifierLoc (), SourceLocation (), LambdaVD, false ,
155
155
DeclarationNameInfo (), QualType (LC->getTypeForDecl (), 0 ), VK_LValue);
156
156
157
- // Initialize Lambda fields
158
- llvm::SmallVector<Expr *, 16 > InitCaptures;
159
-
160
157
auto TargetFunc = dyn_cast<FunctionDecl>(DC);
161
158
auto TargetFuncParam =
162
159
TargetFunc->param_begin (); // Iterator to ParamVarDecl (VarDecl)
163
160
if (TargetFuncParam) {
164
161
for (auto Field : LC->fields ()) {
162
+ auto getExprForPointer = [](Sema &S, const QualType &paramTy,
163
+ DeclRefExpr *DRE) {
164
+ // C++ address space attribute != OpenCL address space attribute
165
+ Expr *qualifiersCast = ImplicitCastExpr::Create (
166
+ S.Context , paramTy, CK_NoOp, DRE, nullptr , VK_LValue);
167
+ Expr *Res =
168
+ ImplicitCastExpr::Create (S.Context , paramTy, CK_LValueToRValue,
169
+ qualifiersCast, nullptr , VK_RValue);
170
+ return Res;
171
+ };
172
+ auto getExprForRange = [](Sema &S, const QualType &paramTy,
173
+ DeclRefExpr *DRE) {
174
+ Expr *Res = ImplicitCastExpr::Create (S.Context , paramTy, CK_NoOp, DRE,
175
+ nullptr , VK_RValue);
176
+ return Res;
177
+ };
178
+
165
179
QualType ParamType = (*TargetFuncParam)->getOriginalType ();
166
180
auto DRE =
167
181
DeclRefExpr::Create (S.Context , NestedNameSpecifierLoc (),
@@ -171,18 +185,20 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
171
185
QualType FieldType = Field->getType ();
172
186
CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl ();
173
187
if (CRD) {
174
- llvm::SmallVector<Expr *, 16 > ParamStmts;
175
188
DeclAccessPair FieldDAP = DeclAccessPair::make (Field, AS_none);
189
+ // lambda.accessor
176
190
auto AccessorME = MemberExpr::Create (
177
191
S.Context , LambdaDRE, false , SourceLocation (),
178
192
NestedNameSpecifierLoc (), SourceLocation (), Field, FieldDAP,
179
193
DeclarationNameInfo (Field->getDeclName (), SourceLocation ()),
180
194
nullptr , Field->getType (), VK_LValue, OK_Ordinary);
181
-
195
+ bool PointerOfAccesorWasSet = false ;
182
196
for (auto Method : CRD->methods ()) {
197
+ llvm::SmallVector<Expr *, 16 > ParamStmts;
183
198
if (Method->getNameInfo ().getName ().getAsString () ==
184
199
" __set_pointer" ) {
185
200
DeclAccessPair MethodDAP = DeclAccessPair::make (Method, AS_none);
201
+ // lambda.accessor.__set_pointer
186
202
auto ME = MemberExpr::Create (
187
203
S.Context , AccessorME, false , SourceLocation (),
188
204
NestedNameSpecifierLoc (), SourceLocation (), Method, MethodDAP,
@@ -199,19 +215,75 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
199
215
// __set_pointer needs one parameter
200
216
QualType paramTy = (*(Method->param_begin ()))->getOriginalType ();
201
217
202
- // C++ address space attribute != OpenCL address space attribute
203
- Expr *qualifiersCast = ImplicitCastExpr::Create (
204
- S.Context , paramTy, CK_NoOp, DRE, nullptr , VK_LValue);
205
- Expr *Res = ImplicitCastExpr::Create (
206
- S.Context , paramTy, CK_LValueToRValue, qualifiersCast,
207
- nullptr , VK_RValue);
218
+ Expr *Res = getExprForPointer (S, paramTy, DRE);
208
219
220
+ // kernel_parameter
209
221
ParamStmts.push_back (Res);
210
-
211
222
// lambda.accessor.__set_pointer(kernel_parameter)
212
223
CXXMemberCallExpr *Call = CXXMemberCallExpr::Create (
213
224
S.Context , ME, ParamStmts, ResultTy, VK, SourceLocation ());
214
225
BodyStmts.push_back (Call);
226
+ PointerOfAccesorWasSet = true ;
227
+ }
228
+ }
229
+ if (PointerOfAccesorWasSet) {
230
+ TargetFuncParam++;
231
+
232
+ ParamType = (*TargetFuncParam)->getOriginalType ();
233
+ DRE = DeclRefExpr::Create (S.Context , NestedNameSpecifierLoc (),
234
+ SourceLocation (), *TargetFuncParam, false ,
235
+ DeclarationNameInfo (), ParamType,
236
+ VK_LValue);
237
+
238
+ FieldType = Field->getType ();
239
+ CRD = FieldType->getAsCXXRecordDecl ();
240
+ if (CRD) {
241
+ FieldDAP = DeclAccessPair::make (Field, AS_none);
242
+ // lambda.accessor
243
+ AccessorME = MemberExpr::Create (
244
+ S.Context , LambdaDRE, false , SourceLocation (),
245
+ NestedNameSpecifierLoc (), SourceLocation (), Field, FieldDAP,
246
+ DeclarationNameInfo (Field->getDeclName (), SourceLocation ()),
247
+ nullptr , Field->getType (), VK_LValue, OK_Ordinary);
248
+
249
+ for (auto Method : CRD->methods ()) {
250
+ llvm::SmallVector<Expr *, 16 > ParamStmts;
251
+ if (Method->getNameInfo ().getName ().getAsString () ==
252
+ " __set_range" ) {
253
+ // lambda.accessor.__set_range
254
+ DeclAccessPair MethodDAP =
255
+ DeclAccessPair::make (Method, AS_none);
256
+ auto ME = MemberExpr::Create (
257
+ S.Context , AccessorME, false , SourceLocation (),
258
+ NestedNameSpecifierLoc (), SourceLocation (), Method,
259
+ MethodDAP, Method->getNameInfo (), nullptr ,
260
+ Method->getType (), VK_LValue, OK_Ordinary);
261
+
262
+ // Not referenced -> not emitted
263
+ S.MarkFunctionReferenced (SourceLocation (), Method, true );
264
+
265
+ QualType ResultTy = Method->getReturnType ();
266
+ ExprValueKind VK = Expr::getValueKindForType (ResultTy);
267
+ ResultTy = ResultTy.getNonLValueExprType (S.Context );
268
+
269
+ // __set_range needs one parameter
270
+ QualType paramTy =
271
+ (*(Method->param_begin ()))->getOriginalType ();
272
+
273
+ Expr *Res = getExprForRange (S, paramTy, DRE);
274
+
275
+ // kernel_parameter
276
+ ParamStmts.push_back (Res);
277
+ // lambda.accessor.__set_range(kernel_parameter)
278
+ CXXMemberCallExpr *Call = CXXMemberCallExpr::Create (
279
+ S.Context , ME, ParamStmts, ResultTy, VK,
280
+ SourceLocation ());
281
+ BodyStmts.push_back (Call);
282
+ }
283
+ }
284
+ } else {
285
+ llvm_unreachable (
286
+ " unsupported accessor and without initialized range" );
215
287
}
216
288
}
217
289
} else if (FieldType->isBuiltinType ()) {
@@ -279,6 +351,7 @@ class Util {
279
351
// / invocation.
280
352
enum VisitorContext {
281
353
pre_visit,
354
+ pre_visit_class_field,
282
355
visit_accessor,
283
356
visit_scalar,
284
357
visit_stream,
@@ -308,7 +381,7 @@ static void visitKernelLambdaCaptures(const CXXRecordDecl *Lambda,
308
381
QualType ArgTy = V->getType ();
309
382
auto F1 = std::get<pre_visit>(Vis);
310
383
F1 (Cnt, V, *Fld);
311
-
384
+ FieldDecl *AccessorRangeField = nullptr ;
312
385
if (Util::isSyclAccessorType (ArgTy)) {
313
386
// accessor parameter context
314
387
const auto *RecordDecl = ArgTy->getAsCXXRecordDecl ();
@@ -317,6 +390,26 @@ static void visitKernelLambdaCaptures(const CXXRecordDecl *Lambda,
317
390
dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl);
318
391
assert (TemplateDecl && " templated accessor type expected" );
319
392
393
+ auto getFieldByName = [](const CXXRecordDecl *RecordDecl,
394
+ std::string Name) {
395
+ FieldDecl *result = nullptr ;
396
+ for (auto jt = RecordDecl->field_begin (); jt != RecordDecl->field_end ();
397
+ ++jt) {
398
+ if (jt->getNameAsString () == Name) {
399
+ result = *jt;
400
+ break ;
401
+ }
402
+ }
403
+ return result;
404
+ };
405
+ FieldDecl *AccessorImplField = getFieldByName (RecordDecl, " __impl" );
406
+ assert (AccessorImplField && " no __impl found in accessor" );
407
+ const auto *AccessorImplRecord =
408
+ AccessorImplField->getType ()->getAsCXXRecordDecl ();
409
+ assert (AccessorImplRecord && " accessor __impl must be of a record type" );
410
+ AccessorRangeField = getFieldByName (AccessorImplRecord, " Range" );
411
+ assert (AccessorRangeField && " no Range found in __impl of accessor" );
412
+
320
413
// First accessor template parameter - data type
321
414
QualType PointeeType = TemplateDecl->getTemplateArgs ()[0 ].getAsType ();
322
415
// Fourth parameter - access target
@@ -335,9 +428,20 @@ static void visitKernelLambdaCaptures(const CXXRecordDecl *Lambda,
335
428
} else {
336
429
llvm_unreachable (" unsupported kernel parameter type" );
337
430
}
338
- // pos -visit context
431
+ // post -visit context
339
432
auto F2 = std::get<post_visit>(Vis);
340
433
F2 (Cnt, V, *Fld);
434
+
435
+ if (AccessorRangeField) {
436
+ // pre-visit context, the same as for the accessor
437
+ auto F1Range = std::get<pre_visit_class_field>(Vis);
438
+ F1Range (Cnt, V, *Fld, AccessorRangeField);
439
+ auto FRange = std::get<visit_scalar>(Vis);
440
+ FRange (Cnt, V, AccessorRangeField);
441
+ // post-visit context
442
+ auto F2Range = std::get<post_visit>(Vis);
443
+ F2Range (Cnt, nullptr , AccessorRangeField);
444
+ }
341
445
}
342
446
assert ((Cpt == CptEnd) && (Fld == FldEnd) &&
343
447
" captures inconsistent with fields" );
@@ -350,6 +454,8 @@ static void BuildArgTys(ASTContext &Context, CXXRecordDecl *Lambda,
350
454
auto Vis = std::make_tuple (
351
455
// pre_visit
352
456
[&](int , VarDecl *, FieldDecl *) {},
457
+ // pre_visit_class_field
458
+ [&](int , VarDecl *, FieldDecl *, FieldDecl *) {},
353
459
// visit_accessor
354
460
[&](int CaptureN, target AccTrg, QualType PointeeType,
355
461
DeclaratorDecl *CapturedVar, FieldDecl *CapturedVal) {
@@ -390,7 +496,10 @@ static void BuildArgTys(ASTContext &Context, CXXRecordDecl *Lambda,
390
496
IdentifierInfo *VarName = 0 ;
391
497
SmallString<8 > Str;
392
498
llvm::raw_svector_ostream OS (Str);
393
- OS << " _arg_" << CapturedVar->getIdentifier ()->getName ();
499
+ IdentifierInfo *Identifier = (CapturedVar != nullptr )
500
+ ? CapturedVar->getIdentifier ()
501
+ : CapturedVal->getIdentifier ();
502
+ OS << " _arg_" << Identifier->getName ();
394
503
VarName = &Context.Idents .get (OS.str ());
395
504
396
505
auto NewVarDecl = VarDecl::Create (
@@ -422,7 +531,24 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
422
531
[&](int CaptureN, VarDecl *CapturedVar, FieldDecl *CapturedVal) {
423
532
// Set offset in bytes
424
533
Offset = static_cast <unsigned >(
425
- Layout.getFieldOffset (CapturedVal->getFieldIndex ()))/8 ;
534
+ Layout.getFieldOffset (CapturedVal->getFieldIndex ())) /
535
+ 8 ;
536
+ },
537
+ // pre_visit_class_field
538
+ [&](int CaptureN, VarDecl *CapturedVar, FieldDecl *CapturedVal,
539
+ FieldDecl *MemberVal) {
540
+ // Set offset of parent in bytes
541
+ Offset = static_cast <unsigned >(
542
+ Layout.getFieldOffset (CapturedVal->getFieldIndex ())) /
543
+ 8 ;
544
+ const RecordDecl *parent = MemberVal->getParent ();
545
+ ASTContext &CtxMember = parent->getASTContext ();
546
+ const ASTRecordLayout &LayoutMember =
547
+ CtxMember.getASTRecordLayout (parent);
548
+ // Add offset relative to parent in bytes
549
+ Offset += static_cast <unsigned >(
550
+ LayoutMember.getFieldOffset (MemberVal->getFieldIndex ())) /
551
+ 8 ;
426
552
},
427
553
// visit_accessor
428
554
[&](int CaptureN, target AccTrg, QualType PointeeType,
@@ -453,15 +579,12 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
453
579
// are removed to make the name shorter. Non-alphanumeric characters in a kernel
454
580
// name are OK - SPIRV and runtimes allow that.
455
581
static std::string constructKernelName (QualType KernelNameType) {
456
- static const std::string Kwds[] = {
457
- std::string (" class" ),
458
- std::string (" struct" )
459
- };
582
+ static const std::string Kwds[] = {std::string (" class" ),
583
+ std::string (" struct" )};
460
584
std::string TStr = KernelNameType.getAsString ();
461
585
462
586
for (const std::string &Kwd : Kwds) {
463
- for (size_t Pos = TStr.find (Kwd);
464
- Pos != StringRef::npos;
587
+ for (size_t Pos = TStr.find (Kwd); Pos != StringRef::npos;
465
588
Pos = TStr.find (Kwd, Pos)) {
466
589
467
590
size_t EndPos = Pos + Kwd.length ();
@@ -593,12 +716,13 @@ static void printDecl(raw_ostream &O, const Decl *D) {
593
716
// \param Depth
594
717
// recursion depth
595
718
//
596
- static void emitForwardClassDecls (raw_ostream &O,
597
- QualType T,
598
- llvm::SmallPtrSetImpl<const void *> &Printed) {
719
+ static void
720
+ emitForwardClassDecls (raw_ostream &O, QualType T,
721
+ llvm::SmallPtrSetImpl<const void *> &Printed) {
599
722
600
723
// peel off the pointer types and get the class/struct type:
601
- for (; T->isPointerType (); T = T->getPointeeType ());
724
+ for (; T->isPointerType (); T = T->getPointeeType ())
725
+ ;
602
726
const CXXRecordDecl *RD = T->getAsCXXRecordDecl ();
603
727
604
728
if (!RD)
@@ -657,7 +781,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
657
781
O << " \n " ;
658
782
O << " // Forward declarations of templated kernel function types:\n " ;
659
783
660
- llvm::SmallPtrSet<const void *, 4 > Printed;
784
+ llvm::SmallPtrSet<const void *, 4 > Printed;
661
785
662
786
for (const KernelDesc &K : KernelDescs) {
663
787
emitForwardClassDecls (O, K.NameType , Printed);
@@ -737,12 +861,11 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
737
861
738
862
for (const KernelDesc &K : KernelDescs) {
739
863
const size_t N = K.Params .size ();
740
- O << " template <> struct KernelInfo<" <<
741
- K.NameType .getAsString () << " > {\n " ;
742
- O << " static constexpr const char* getName() { return \" "
743
- << K.Name << " \" ; }\n " ;
744
- O << " static constexpr unsigned getNumParams() { return "
745
- << N << " ; }\n " ;
864
+ O << " template <> struct KernelInfo<" << K.NameType .getAsString ()
865
+ << " > {\n " ;
866
+ O << " static constexpr const char* getName() { return \" " << K.Name
867
+ << " \" ; }\n " ;
868
+ O << " static constexpr unsigned getNumParams() { return " << N << " ; }\n " ;
746
869
O << " static constexpr const kernel_param_desc_t& " ;
747
870
O << " getParamDesc(unsigned i) {\n " ;
748
871
O << " return kernel_signatures[i+" << CurStart << " ];\n " ;
@@ -757,7 +880,6 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
757
880
O << " \n " ;
758
881
}
759
882
760
-
761
883
bool SYCLIntegrationHeader::emit (const StringRef &IntHeaderName) {
762
884
if (IntHeaderName.empty ())
763
885
return false ;
0 commit comments