Skip to content

Commit 0412db3

Browse files
committed
Array elements are now passed as individual parameters.
1 parent d87b2cc commit 0412db3

File tree

1 file changed

+76
-71
lines changed

1 file changed

+76
-71
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 76 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -688,22 +688,6 @@ static ParamDesc makeParamDesc(ASTContext &Ctx, const CXXBaseSpecifier &Src,
688688
Ctx.getTrivialTypeSourceInfo(Ty));
689689
}
690690

691-
// Create a new class around a field - used to wrap arrays.
692-
static RecordDecl *wrapAnArray(ASTContext &Ctx, const QualType ArgTy,
693-
FieldDecl *&Field) {
694-
RecordDecl *NewClass = Ctx.buildImplicitRecord("wrapped_array");
695-
NewClass->startDefinition();
696-
Field = FieldDecl::Create(
697-
Ctx, NewClass, SourceLocation(), SourceLocation(),
698-
/*Id=*/nullptr, ArgTy,
699-
Ctx.getTrivialTypeSourceInfo(ArgTy, SourceLocation()),
700-
/*BW=*/nullptr, /*Mutable=*/false, /*InitStyle=*/ICIS_NoInit);
701-
Field->setAccess(AS_public);
702-
NewClass->addDecl(Field);
703-
NewClass->completeDefinition();
704-
return NewClass;
705-
}
706-
707691
/// \return the target of given SYCL accessor type
708692
static target getAccessTarget(const ClassTemplateSpecializationDecl *AccTy) {
709693
return static_cast<target>(
@@ -799,15 +783,21 @@ static void VisitField(CXXRecordDecl *Owner, RangeTy &&Item, QualType ItemTy,
799783
Handlers &... handlers) {
800784
if (Util::isSyclAccessorType(ItemTy))
801785
KF_FOR_EACH(handleSyclAccessorType, Item, ItemTy);
802-
if (Util::isSyclStreamType(ItemTy))
786+
else if (Util::isSyclStreamType(ItemTy))
803787
KF_FOR_EACH(handleSyclStreamType, Item, ItemTy);
804-
if (ItemTy->isStructureOrClassType())
788+
else if (ItemTy->isStructureOrClassType())
805789
VisitAccessorWrapper(Owner, Item, ItemTy->getAsCXXRecordDecl(),
806790
handlers...);
807-
if (ItemTy->isArrayType())
791+
else if (ItemTy->isArrayType())
808792
VisitArrayElements(Item, ItemTy, handlers...);
809793
}
810794

795+
template <typename RangeTy, typename... Handlers>
796+
static void VisitScalarField(CXXRecordDecl *Owner, RangeTy &&Item, QualType ItemTy,
797+
Handlers &... handlers) {
798+
KF_FOR_EACH(handleScalarType, Item, ItemTy);
799+
}
800+
811801
template <typename RangeTy, typename... Handlers>
812802
static void VisitArrayElements(RangeTy Item, QualType FieldTy,
813803
Handlers &... handlers) {
@@ -816,7 +806,10 @@ static void VisitArrayElements(RangeTy Item, QualType FieldTy,
816806
int64_t ElemCount = CAT->getSize().getSExtValue();
817807
std::initializer_list<int>{(handlers.enterArray(), 0)...};
818808
for (int64_t Count = 0; Count < ElemCount; Count++) {
819-
VisitField(nullptr, Item, ET, handlers...);
809+
if (ET->isScalarType())
810+
VisitScalarField(nullptr, Item, ET, handlers...);
811+
else
812+
VisitField(nullptr, Item, ET, handlers...);
820813
(void)std::initializer_list<int>{(handlers.nextElement(ET), 0)...};
821814
}
822815
(void)std::initializer_list<int>{(handlers.leaveArray(ET, ElemCount), 0)...};
@@ -919,6 +912,9 @@ template <typename Derived> class SyclKernelFieldHandler {
919912
virtual bool handleReferenceType(FieldDecl *, QualType) { return true; }
920913
virtual bool handlePointerType(FieldDecl *, QualType) { return true; }
921914
virtual bool handleArrayType(FieldDecl *, QualType) { return true; }
915+
virtual bool handleScalarType(const CXXBaseSpecifier &, QualType) {
916+
return true;
917+
}
922918
virtual bool handleScalarType(FieldDecl *, QualType) { return true; }
923919
// Most handlers shouldn't be handling this, just the field checker.
924920
virtual bool handleOtherType(FieldDecl *, QualType) { return true; }
@@ -1003,7 +999,8 @@ class SyclKernelFieldChecker
1003999

10041000
public:
10051001
SyclKernelFieldChecker(Sema &S)
1006-
: SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
1002+
: SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {
1003+
}
10071004
bool isValid() { return !IsInvalid; }
10081005

10091006
bool handleReferenceType(FieldDecl *FD, QualType FieldTy) final {
@@ -1052,6 +1049,10 @@ class SyclKernelDeclCreator
10521049
size_t LastParamIndex = 0;
10531050

10541051
void addParam(const FieldDecl *FD, QualType FieldTy) {
1052+
const ConstantArrayType *CAT =
1053+
SemaRef.getASTContext().getAsConstantArrayType(FieldTy);
1054+
if (CAT)
1055+
FieldTy = CAT->getElementType();
10551056
ParamDesc newParamDesc = makeParamDesc(FD, FieldTy);
10561057
addParam(newParamDesc, FieldTy);
10571058
}
@@ -1068,7 +1069,6 @@ class SyclKernelDeclCreator
10681069
SemaRef.getASTContext(), KernelDecl, SourceLocation(), SourceLocation(),
10691070
std::get<1>(newParamDesc), std::get<0>(newParamDesc),
10701071
std::get<2>(newParamDesc), SC_None, /*DefArg*/ nullptr);
1071-
10721072
NewParam->setScopeInfo(0, Params.size());
10731073
NewParam->setIsUsed();
10741074

@@ -1131,7 +1131,8 @@ class SyclKernelDeclCreator
11311131
: SyclKernelFieldHandler(S),
11321132
KernelDecl(createKernelDecl(S.getASTContext(), Name, Loc, IsInline,
11331133
IsSIMDKernel)),
1134-
ArgChecker(ArgChecker), FuncContext(SemaRef, KernelDecl) {}
1134+
ArgChecker(ArgChecker), FuncContext(SemaRef, KernelDecl) {
1135+
}
11351136

11361137
~SyclKernelDeclCreator() {
11371138
ASTContext &Ctx = SemaRef.getASTContext();
@@ -1189,18 +1190,12 @@ class SyclKernelDeclCreator
11891190
return true;
11901191
}
11911192

1192-
bool handleArrayType(FieldDecl *FD, QualType FieldTy) final {
1193-
RecordDecl *NewClass = wrapAnArray(SemaRef.getASTContext(), FieldTy, FD);
1194-
QualType ST = SemaRef.getASTContext().getRecordType(NewClass);
1195-
addParam(FD, ST);
1196-
return true;
1197-
}
1198-
11991193
bool handleScalarType(FieldDecl *FD, QualType FieldTy) final {
12001194
addParam(FD, FieldTy);
12011195
return true;
12021196
}
12031197

1198+
//FIXME Remove this function when structs are replaced by their fields
12041199
bool handleStructType(FieldDecl *FD, QualType FieldTy) final {
12051200
addParam(FD, FieldTy);
12061201
return true;
@@ -1225,6 +1220,8 @@ class SyclKernelDeclCreator
12251220
return ArrayRef<ParmVarDecl *>(std::begin(Params) + LastParamIndex,
12261221
std::end(Params));
12271222
}
1223+
1224+
using SyclKernelFieldHandler::handleScalarType;
12281225
};
12291226

12301227
class SyclKernelBodyCreator
@@ -1309,11 +1306,9 @@ class SyclKernelBodyCreator
13091306
return Result;
13101307
}
13111308

1312-
void createExprForStructOrScalar(FieldDecl *FD) {
1309+
Expr *createInitExpr(FieldDecl *FD) {
13131310
ParmVarDecl *KernelParameter =
13141311
DeclCreator.getParamVarDeclsForCurrentField()[0];
1315-
InitializedEntity Entity =
1316-
InitializedEntity::InitializeMember(FD, &VarEntity);
13171312
QualType ParamType = KernelParameter->getOriginalType();
13181313
Expr *DRE = SemaRef.BuildDeclRefExpr(KernelParameter, ParamType, VK_LValue,
13191314
SourceLocation());
@@ -1323,32 +1318,49 @@ class SyclKernelBodyCreator
13231318
DRE = ImplicitCastExpr::Create(SemaRef.Context, FD->getType(),
13241319
CK_AddressSpaceConversion, DRE, nullptr,
13251320
VK_RValue);
1321+
return DRE;
1322+
}
1323+
1324+
void createExprForStructOrScalar(FieldDecl *FD) {
1325+
InitializedEntity Entity =
1326+
InitializedEntity::InitializeMember(FD, &VarEntity);
13261327
InitializationKind InitKind =
13271328
InitializationKind::CreateCopy(SourceLocation(), SourceLocation());
1329+
Expr *DRE = createInitExpr(FD);
13281330
InitializationSequence InitSeq(SemaRef, Entity, InitKind, DRE);
1329-
13301331
ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, DRE);
13311332
InitExprs.push_back(MemberInit.get());
13321333
}
13331334

1334-
void createExprForArray(FieldDecl *FD) {
1335-
ParmVarDecl *KernelParameter =
1336-
DeclCreator.getParamVarDeclsForCurrentField()[0];
1337-
QualType ParamType = KernelParameter->getOriginalType();
1338-
CXXRecordDecl *WrapperStruct = ParamType->getAsCXXRecordDecl();
1339-
// The first and only field of the wrapper struct is the array
1340-
FieldDecl *Array = *(WrapperStruct->field_begin());
1341-
Expr *DRE = SemaRef.BuildDeclRefExpr(KernelParameter, ParamType, VK_LValue,
1342-
SourceLocation());
1343-
Expr *InitExpr = BuildMemberExpr(DRE, Array);
1344-
InitializationKind InitKind = InitializationKind::CreateDirect(
1345-
SourceLocation(), SourceLocation(), SourceLocation());
1346-
InitializedEntity Entity = InitializedEntity::InitializeLambdaCapture(
1347-
nullptr, Array->getType(), SourceLocation());
1348-
InitializationSequence InitSeq(SemaRef, Entity, InitKind, InitExpr);
1349-
ExprResult MemberInit =
1350-
InitSeq.Perform(SemaRef, Entity, InitKind, InitExpr);
1351-
InitExprs.push_back(MemberInit.get());
1335+
void createExprForScalarElement(FieldDecl *FD, QualType FieldTy) {
1336+
InitializedEntity ArrayEntity =
1337+
InitializedEntity::InitializeMember(FD, &VarEntity);
1338+
InitializationKind InitKind =
1339+
InitializationKind::CreateCopy(SourceLocation(), SourceLocation());
1340+
Expr *DRE = createInitExpr(FD);
1341+
Expr *Idx = dyn_cast<ArraySubscriptExpr>(MemberExprBases.back())->getIdx();
1342+
llvm::APSInt Result;
1343+
SemaRef.VerifyIntegerConstantExpression(Idx, &Result);
1344+
uint64_t IntIdx = Result.getZExtValue();
1345+
InitializedEntity Entity = InitializedEntity::InitializeElement(
1346+
SemaRef.getASTContext(), IntIdx, ArrayEntity);
1347+
InitializationSequence InitSeq(SemaRef, Entity, InitKind, DRE);
1348+
ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, DRE);
1349+
llvm::SmallVector<Expr *, 16> ArrayInitExprs;
1350+
if (IntIdx > 0) {
1351+
// Continue with the current InitList
1352+
InitListExpr *ILE = cast<InitListExpr>(InitExprs.back());
1353+
InitExprs.pop_back();
1354+
llvm::ArrayRef<Expr *> L = ILE->inits();
1355+
for (size_t I = 0; I < L.size(); I++)
1356+
ArrayInitExprs.push_back(L[I]);
1357+
}
1358+
ArrayInitExprs.push_back(MemberInit.get());
1359+
Expr *ILE = new (SemaRef.getASTContext())
1360+
InitListExpr(SemaRef.getASTContext(), SourceLocation(), ArrayInitExprs,
1361+
SourceLocation());
1362+
ILE->setType(FD->getType());
1363+
InitExprs.push_back(ILE);
13521364
}
13531365

13541366
void createSpecialMethodCall(const CXXRecordDecl *SpecialClass, Expr *Base,
@@ -1479,18 +1491,17 @@ class SyclKernelBodyCreator
14791491
return true;
14801492
}
14811493

1494+
//FIXME Remove this function when structs are replaced by their fields
14821495
bool handleStructType(FieldDecl *FD, QualType FieldTy) final {
14831496
createExprForStructOrScalar(FD);
14841497
return true;
14851498
}
14861499

14871500
bool handleScalarType(FieldDecl *FD, QualType FieldTy) final {
1488-
createExprForStructOrScalar(FD);
1489-
return true;
1490-
}
1491-
1492-
bool handleArrayType(FieldDecl *FD, QualType FieldTy) final {
1493-
createExprForArray(FD);
1501+
if (dyn_cast<ArraySubscriptExpr>(MemberExprBases.back()))
1502+
createExprForScalarElement(FD, FieldTy);
1503+
else
1504+
createExprForStructOrScalar(FD);
14941505
return true;
14951506
}
14961507

@@ -1531,6 +1542,7 @@ class SyclKernelBodyCreator
15311542

15321543
using SyclKernelFieldHandler::enterArray;
15331544
using SyclKernelFieldHandler::enterField;
1545+
using SyclKernelFieldHandler::handleScalarType;
15341546
using SyclKernelFieldHandler::leaveField;
15351547
};
15361548

@@ -1547,13 +1559,9 @@ class SyclKernelIntHeaderCreator
15471559
uint64_t Size;
15481560
const ConstantArrayType *CAT =
15491561
SemaRef.getASTContext().getAsConstantArrayType(ArgTy);
1550-
if (CAT) {
1551-
QualType ET = CAT->getElementType();
1552-
Size = static_cast<size_t>(CAT->getSize().getZExtValue()) *
1553-
SemaRef.getASTContext().getTypeSizeInChars(ET).getQuantity();
1554-
} else {
1555-
Size = SemaRef.getASTContext().getTypeSizeInChars(ArgTy).getQuantity();
1556-
}
1562+
if (CAT)
1563+
ArgTy = CAT->getElementType();
1564+
Size = SemaRef.getASTContext().getTypeSizeInChars(ArgTy).getQuantity();
15571565
Header.addParamDesc(Kind, static_cast<unsigned>(Size),
15581566
static_cast<unsigned>(CurOffset));
15591567
}
@@ -1630,12 +1638,7 @@ class SyclKernelIntHeaderCreator
16301638
return true;
16311639
}
16321640

1633-
bool handleArrayType(FieldDecl *FD, QualType FieldTy) final {
1634-
wrapAnArray(SemaRef.getASTContext(), FieldTy, FD);
1635-
addParam(FD, FD->getType(), SYCLIntegrationHeader::kind_std_layout);
1636-
return true;
1637-
}
1638-
1641+
//FIXME Remove this function when structs are replaced by their fields
16391642
bool handleStructType(FieldDecl *FD, QualType FieldTy) final {
16401643
addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout);
16411644
return true;
@@ -1692,6 +1695,8 @@ class SyclKernelIntHeaderCreator
16921695
}
16931696
CurOffset -= ArraySize;
16941697
}
1698+
1699+
using SyclKernelFieldHandler::handleScalarType;
16951700
};
16961701
} // namespace
16971702

0 commit comments

Comments
 (0)