@@ -688,22 +688,6 @@ static ParamDesc makeParamDesc(ASTContext &Ctx, const CXXBaseSpecifier &Src,
688
688
Ctx.getTrivialTypeSourceInfo (Ty));
689
689
}
690
690
691
- // Create a new class around a field - used to wrap arrays.
692
- static RecordDecl *wrapAnArray (ASTContext &Ctx, const QualType ArgTy,
693
- FieldDecl *&Field) {
694
- RecordDecl *NewClass = Ctx.buildImplicitRecord (" wrapped_array" );
695
- NewClass->startDefinition ();
696
- Field = FieldDecl::Create (
697
- Ctx, NewClass, SourceLocation (), SourceLocation (),
698
- /* Id=*/ nullptr , ArgTy,
699
- Ctx.getTrivialTypeSourceInfo (ArgTy, SourceLocation ()),
700
- /* BW=*/ nullptr , /* Mutable=*/ false , /* InitStyle=*/ ICIS_NoInit);
701
- Field->setAccess (AS_public);
702
- NewClass->addDecl (Field);
703
- NewClass->completeDefinition ();
704
- return NewClass;
705
- }
706
-
707
691
// / \return the target of given SYCL accessor type
708
692
static target getAccessTarget (const ClassTemplateSpecializationDecl *AccTy) {
709
693
return static_cast <target>(
@@ -799,15 +783,21 @@ static void VisitField(CXXRecordDecl *Owner, RangeTy &&Item, QualType ItemTy,
799
783
Handlers &... handlers) {
800
784
if (Util::isSyclAccessorType (ItemTy))
801
785
KF_FOR_EACH (handleSyclAccessorType, Item, ItemTy);
802
- if (Util::isSyclStreamType (ItemTy))
786
+ else if (Util::isSyclStreamType (ItemTy))
803
787
KF_FOR_EACH (handleSyclStreamType, Item, ItemTy);
804
- if (ItemTy->isStructureOrClassType ())
788
+ else if (ItemTy->isStructureOrClassType ())
805
789
VisitAccessorWrapper (Owner, Item, ItemTy->getAsCXXRecordDecl (),
806
790
handlers...);
807
- if (ItemTy->isArrayType ())
791
+ else if (ItemTy->isArrayType ())
808
792
VisitArrayElements (Item, ItemTy, handlers...);
809
793
}
810
794
795
+ template <typename RangeTy, typename ... Handlers>
796
+ static void VisitScalarField (CXXRecordDecl *Owner, RangeTy &&Item, QualType ItemTy,
797
+ Handlers &... handlers) {
798
+ KF_FOR_EACH (handleScalarType, Item, ItemTy);
799
+ }
800
+
811
801
template <typename RangeTy, typename ... Handlers>
812
802
static void VisitArrayElements (RangeTy Item, QualType FieldTy,
813
803
Handlers &... handlers) {
@@ -816,7 +806,10 @@ static void VisitArrayElements(RangeTy Item, QualType FieldTy,
816
806
int64_t ElemCount = CAT->getSize ().getSExtValue ();
817
807
std::initializer_list<int >{(handlers.enterArray (), 0 )...};
818
808
for (int64_t Count = 0 ; Count < ElemCount; Count++) {
819
- VisitField (nullptr , Item, ET, handlers...);
809
+ if (ET->isScalarType ())
810
+ VisitScalarField (nullptr , Item, ET, handlers...);
811
+ else
812
+ VisitField (nullptr , Item, ET, handlers...);
820
813
(void )std::initializer_list<int >{(handlers.nextElement (ET), 0 )...};
821
814
}
822
815
(void )std::initializer_list<int >{(handlers.leaveArray (ET, ElemCount), 0 )...};
@@ -919,6 +912,9 @@ template <typename Derived> class SyclKernelFieldHandler {
919
912
virtual bool handleReferenceType (FieldDecl *, QualType) { return true ; }
920
913
virtual bool handlePointerType (FieldDecl *, QualType) { return true ; }
921
914
virtual bool handleArrayType (FieldDecl *, QualType) { return true ; }
915
+ virtual bool handleScalarType (const CXXBaseSpecifier &, QualType) {
916
+ return true ;
917
+ }
922
918
virtual bool handleScalarType (FieldDecl *, QualType) { return true ; }
923
919
// Most handlers shouldn't be handling this, just the field checker.
924
920
virtual bool handleOtherType (FieldDecl *, QualType) { return true ; }
@@ -1003,7 +999,8 @@ class SyclKernelFieldChecker
1003
999
1004
1000
public:
1005
1001
SyclKernelFieldChecker (Sema &S)
1006
- : SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
1002
+ : SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {
1003
+ }
1007
1004
bool isValid () { return !IsInvalid; }
1008
1005
1009
1006
bool handleReferenceType (FieldDecl *FD, QualType FieldTy) final {
@@ -1052,6 +1049,10 @@ class SyclKernelDeclCreator
1052
1049
size_t LastParamIndex = 0 ;
1053
1050
1054
1051
void addParam (const FieldDecl *FD, QualType FieldTy) {
1052
+ const ConstantArrayType *CAT =
1053
+ SemaRef.getASTContext ().getAsConstantArrayType (FieldTy);
1054
+ if (CAT)
1055
+ FieldTy = CAT->getElementType ();
1055
1056
ParamDesc newParamDesc = makeParamDesc (FD, FieldTy);
1056
1057
addParam (newParamDesc, FieldTy);
1057
1058
}
@@ -1068,7 +1069,6 @@ class SyclKernelDeclCreator
1068
1069
SemaRef.getASTContext (), KernelDecl, SourceLocation (), SourceLocation (),
1069
1070
std::get<1 >(newParamDesc), std::get<0 >(newParamDesc),
1070
1071
std::get<2 >(newParamDesc), SC_None, /* DefArg*/ nullptr );
1071
-
1072
1072
NewParam->setScopeInfo (0 , Params.size ());
1073
1073
NewParam->setIsUsed ();
1074
1074
@@ -1131,7 +1131,8 @@ class SyclKernelDeclCreator
1131
1131
: SyclKernelFieldHandler(S),
1132
1132
KernelDecl (createKernelDecl(S.getASTContext(), Name, Loc, IsInline,
1133
1133
IsSIMDKernel)),
1134
- ArgChecker(ArgChecker), FuncContext(SemaRef, KernelDecl) {}
1134
+ ArgChecker(ArgChecker), FuncContext(SemaRef, KernelDecl) {
1135
+ }
1135
1136
1136
1137
~SyclKernelDeclCreator () {
1137
1138
ASTContext &Ctx = SemaRef.getASTContext ();
@@ -1189,18 +1190,12 @@ class SyclKernelDeclCreator
1189
1190
return true ;
1190
1191
}
1191
1192
1192
- bool handleArrayType (FieldDecl *FD, QualType FieldTy) final {
1193
- RecordDecl *NewClass = wrapAnArray (SemaRef.getASTContext (), FieldTy, FD);
1194
- QualType ST = SemaRef.getASTContext ().getRecordType (NewClass);
1195
- addParam (FD, ST);
1196
- return true ;
1197
- }
1198
-
1199
1193
bool handleScalarType (FieldDecl *FD, QualType FieldTy) final {
1200
1194
addParam (FD, FieldTy);
1201
1195
return true ;
1202
1196
}
1203
1197
1198
+ // FIXME Remove this function when structs are replaced by their fields
1204
1199
bool handleStructType (FieldDecl *FD, QualType FieldTy) final {
1205
1200
addParam (FD, FieldTy);
1206
1201
return true ;
@@ -1225,6 +1220,8 @@ class SyclKernelDeclCreator
1225
1220
return ArrayRef<ParmVarDecl *>(std::begin (Params) + LastParamIndex,
1226
1221
std::end (Params));
1227
1222
}
1223
+
1224
+ using SyclKernelFieldHandler::handleScalarType;
1228
1225
};
1229
1226
1230
1227
class SyclKernelBodyCreator
@@ -1309,11 +1306,9 @@ class SyclKernelBodyCreator
1309
1306
return Result;
1310
1307
}
1311
1308
1312
- void createExprForStructOrScalar (FieldDecl *FD) {
1309
+ Expr * createInitExpr (FieldDecl *FD) {
1313
1310
ParmVarDecl *KernelParameter =
1314
1311
DeclCreator.getParamVarDeclsForCurrentField ()[0 ];
1315
- InitializedEntity Entity =
1316
- InitializedEntity::InitializeMember (FD, &VarEntity);
1317
1312
QualType ParamType = KernelParameter->getOriginalType ();
1318
1313
Expr *DRE = SemaRef.BuildDeclRefExpr (KernelParameter, ParamType, VK_LValue,
1319
1314
SourceLocation ());
@@ -1323,32 +1318,49 @@ class SyclKernelBodyCreator
1323
1318
DRE = ImplicitCastExpr::Create (SemaRef.Context , FD->getType (),
1324
1319
CK_AddressSpaceConversion, DRE, nullptr ,
1325
1320
VK_RValue);
1321
+ return DRE;
1322
+ }
1323
+
1324
+ void createExprForStructOrScalar (FieldDecl *FD) {
1325
+ InitializedEntity Entity =
1326
+ InitializedEntity::InitializeMember (FD, &VarEntity);
1326
1327
InitializationKind InitKind =
1327
1328
InitializationKind::CreateCopy (SourceLocation (), SourceLocation ());
1329
+ Expr *DRE = createInitExpr (FD);
1328
1330
InitializationSequence InitSeq (SemaRef, Entity, InitKind, DRE);
1329
-
1330
1331
ExprResult MemberInit = InitSeq.Perform (SemaRef, Entity, InitKind, DRE);
1331
1332
InitExprs.push_back (MemberInit.get ());
1332
1333
}
1333
1334
1334
- void createExprForArray (FieldDecl *FD) {
1335
- ParmVarDecl *KernelParameter =
1336
- DeclCreator.getParamVarDeclsForCurrentField ()[0 ];
1337
- QualType ParamType = KernelParameter->getOriginalType ();
1338
- CXXRecordDecl *WrapperStruct = ParamType->getAsCXXRecordDecl ();
1339
- // The first and only field of the wrapper struct is the array
1340
- FieldDecl *Array = *(WrapperStruct->field_begin ());
1341
- Expr *DRE = SemaRef.BuildDeclRefExpr (KernelParameter, ParamType, VK_LValue,
1342
- SourceLocation ());
1343
- Expr *InitExpr = BuildMemberExpr (DRE, Array);
1344
- InitializationKind InitKind = InitializationKind::CreateDirect (
1345
- SourceLocation (), SourceLocation (), SourceLocation ());
1346
- InitializedEntity Entity = InitializedEntity::InitializeLambdaCapture (
1347
- nullptr , Array->getType (), SourceLocation ());
1348
- InitializationSequence InitSeq (SemaRef, Entity, InitKind, InitExpr);
1349
- ExprResult MemberInit =
1350
- InitSeq.Perform (SemaRef, Entity, InitKind, InitExpr);
1351
- InitExprs.push_back (MemberInit.get ());
1335
+ void createExprForScalarElement (FieldDecl *FD, QualType FieldTy) {
1336
+ InitializedEntity ArrayEntity =
1337
+ InitializedEntity::InitializeMember (FD, &VarEntity);
1338
+ InitializationKind InitKind =
1339
+ InitializationKind::CreateCopy (SourceLocation (), SourceLocation ());
1340
+ Expr *DRE = createInitExpr (FD);
1341
+ Expr *Idx = dyn_cast<ArraySubscriptExpr>(MemberExprBases.back ())->getIdx ();
1342
+ llvm::APSInt Result;
1343
+ SemaRef.VerifyIntegerConstantExpression (Idx, &Result);
1344
+ uint64_t IntIdx = Result.getZExtValue ();
1345
+ InitializedEntity Entity = InitializedEntity::InitializeElement (
1346
+ SemaRef.getASTContext (), IntIdx, ArrayEntity);
1347
+ InitializationSequence InitSeq (SemaRef, Entity, InitKind, DRE);
1348
+ ExprResult MemberInit = InitSeq.Perform (SemaRef, Entity, InitKind, DRE);
1349
+ llvm::SmallVector<Expr *, 16 > ArrayInitExprs;
1350
+ if (IntIdx > 0 ) {
1351
+ // Continue with the current InitList
1352
+ InitListExpr *ILE = cast<InitListExpr>(InitExprs.back ());
1353
+ InitExprs.pop_back ();
1354
+ llvm::ArrayRef<Expr *> L = ILE->inits ();
1355
+ for (size_t I = 0 ; I < L.size (); I++)
1356
+ ArrayInitExprs.push_back (L[I]);
1357
+ }
1358
+ ArrayInitExprs.push_back (MemberInit.get ());
1359
+ Expr *ILE = new (SemaRef.getASTContext ())
1360
+ InitListExpr (SemaRef.getASTContext (), SourceLocation (), ArrayInitExprs,
1361
+ SourceLocation ());
1362
+ ILE->setType (FD->getType ());
1363
+ InitExprs.push_back (ILE);
1352
1364
}
1353
1365
1354
1366
void createSpecialMethodCall (const CXXRecordDecl *SpecialClass, Expr *Base,
@@ -1479,18 +1491,17 @@ class SyclKernelBodyCreator
1479
1491
return true ;
1480
1492
}
1481
1493
1494
+ // FIXME Remove this function when structs are replaced by their fields
1482
1495
bool handleStructType (FieldDecl *FD, QualType FieldTy) final {
1483
1496
createExprForStructOrScalar (FD);
1484
1497
return true ;
1485
1498
}
1486
1499
1487
1500
bool handleScalarType (FieldDecl *FD, QualType FieldTy) final {
1488
- createExprForStructOrScalar (FD);
1489
- return true ;
1490
- }
1491
-
1492
- bool handleArrayType (FieldDecl *FD, QualType FieldTy) final {
1493
- createExprForArray (FD);
1501
+ if (dyn_cast<ArraySubscriptExpr>(MemberExprBases.back ()))
1502
+ createExprForScalarElement (FD, FieldTy);
1503
+ else
1504
+ createExprForStructOrScalar (FD);
1494
1505
return true ;
1495
1506
}
1496
1507
@@ -1531,6 +1542,7 @@ class SyclKernelBodyCreator
1531
1542
1532
1543
using SyclKernelFieldHandler::enterArray;
1533
1544
using SyclKernelFieldHandler::enterField;
1545
+ using SyclKernelFieldHandler::handleScalarType;
1534
1546
using SyclKernelFieldHandler::leaveField;
1535
1547
};
1536
1548
@@ -1547,13 +1559,9 @@ class SyclKernelIntHeaderCreator
1547
1559
uint64_t Size;
1548
1560
const ConstantArrayType *CAT =
1549
1561
SemaRef.getASTContext ().getAsConstantArrayType (ArgTy);
1550
- if (CAT) {
1551
- QualType ET = CAT->getElementType ();
1552
- Size = static_cast <size_t >(CAT->getSize ().getZExtValue ()) *
1553
- SemaRef.getASTContext ().getTypeSizeInChars (ET).getQuantity ();
1554
- } else {
1555
- Size = SemaRef.getASTContext ().getTypeSizeInChars (ArgTy).getQuantity ();
1556
- }
1562
+ if (CAT)
1563
+ ArgTy = CAT->getElementType ();
1564
+ Size = SemaRef.getASTContext ().getTypeSizeInChars (ArgTy).getQuantity ();
1557
1565
Header.addParamDesc (Kind, static_cast <unsigned >(Size),
1558
1566
static_cast <unsigned >(CurOffset));
1559
1567
}
@@ -1630,12 +1638,7 @@ class SyclKernelIntHeaderCreator
1630
1638
return true ;
1631
1639
}
1632
1640
1633
- bool handleArrayType (FieldDecl *FD, QualType FieldTy) final {
1634
- wrapAnArray (SemaRef.getASTContext (), FieldTy, FD);
1635
- addParam (FD, FD->getType (), SYCLIntegrationHeader::kind_std_layout);
1636
- return true ;
1637
- }
1638
-
1641
+ // FIXME Remove this function when structs are replaced by their fields
1639
1642
bool handleStructType (FieldDecl *FD, QualType FieldTy) final {
1640
1643
addParam (FD, FieldTy, SYCLIntegrationHeader::kind_std_layout);
1641
1644
return true ;
@@ -1692,6 +1695,8 @@ class SyclKernelIntHeaderCreator
1692
1695
}
1693
1696
CurOffset -= ArraySize;
1694
1697
}
1698
+
1699
+ using SyclKernelFieldHandler::handleScalarType;
1695
1700
};
1696
1701
} // namespace
1697
1702
0 commit comments