@@ -9828,6 +9828,23 @@ buildPreInits(ASTContext &Context,
9828
9828
return nullptr;
9829
9829
}
9830
9830
9831
+ /// Build pre-init statement for the given statements.
9832
+ static Stmt *buildPreInits(ASTContext &Context, ArrayRef<Stmt *> PreInits) {
9833
+ if (!PreInits.empty()) {
9834
+ SmallVector<Stmt *> Stmts;
9835
+ for (Stmt *S : PreInits) {
9836
+ // Do not nest CompoundStmts.
9837
+ if (auto *CS = dyn_cast<CompoundStmt>(S)) {
9838
+ llvm::append_range(Stmts, CS->body());
9839
+ continue;
9840
+ }
9841
+ Stmts.push_back(S);
9842
+ }
9843
+ return CompoundStmt::Create(Context, PreInits, FPOptionsOverride(), {}, {});
9844
+ }
9845
+ return nullptr;
9846
+ }
9847
+
9831
9848
/// Build postupdate expression for the given list of postupdates expressions.
9832
9849
static Expr *buildPostUpdate(Sema &S, ArrayRef<Expr *> PostUpdates) {
9833
9850
Expr *PostUpdate = nullptr;
@@ -9924,11 +9941,24 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
9924
9941
Stmt *DependentPreInits = Transform->getPreInits();
9925
9942
if (!DependentPreInits)
9926
9943
return;
9927
- for (Decl *C : cast<DeclStmt>(DependentPreInits)->getDeclGroup()) {
9928
- auto *D = cast<VarDecl>(C);
9929
- DeclRefExpr *Ref = buildDeclRefExpr(SemaRef, D, D->getType(),
9930
- Transform->getBeginLoc());
9931
- Captures[Ref] = Ref;
9944
+
9945
+ // Search for pre-init declared variables that need to be captured
9946
+ // to be referenceable inside the directive.
9947
+ SmallVector<Stmt *> Constituents;
9948
+ if (auto *CS = dyn_cast<CompoundStmt>(DependentPreInits))
9949
+ llvm::append_range(Constituents, CS->body());
9950
+ else
9951
+ Constituents.push_back(DependentPreInits);
9952
+ for (Stmt *S : Constituents) {
9953
+ if (DeclStmt *DC = dyn_cast<DeclStmt>(S)) {
9954
+ for (Decl *C : DC->decls()) {
9955
+ auto *D = cast<VarDecl>(C);
9956
+ DeclRefExpr *Ref = buildDeclRefExpr(
9957
+ SemaRef, D, D->getType().getNonReferenceType(),
9958
+ Transform->getBeginLoc());
9959
+ Captures[Ref] = Ref;
9960
+ }
9961
+ }
9932
9962
}
9933
9963
}))
9934
9964
return 0;
@@ -15059,9 +15089,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective(
15059
15089
bool SemaOpenMP::checkTransformableLoopNest(
15060
15090
OpenMPDirectiveKind Kind, Stmt *AStmt, int NumLoops,
15061
15091
SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
15062
- Stmt *&Body,
15063
- SmallVectorImpl<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>>
15064
- &OriginalInits) {
15092
+ Stmt *&Body, SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits) {
15065
15093
OriginalInits.emplace_back();
15066
15094
bool Result = OMPLoopBasedDirective::doForAllLoops(
15067
15095
AStmt->IgnoreContainers(), /*TryImperfectlyNestedLoops=*/false, NumLoops,
@@ -15097,14 +15125,75 @@ bool SemaOpenMP::checkTransformableLoopNest(
15097
15125
llvm_unreachable("Unhandled loop transformation");
15098
15126
if (!DependentPreInits)
15099
15127
return;
15100
- llvm::append_range(OriginalInits.back(),
15101
- cast<DeclStmt>(DependentPreInits)->getDeclGroup());
15128
+ // CompoundStmts are used as lists of other statements, add their
15129
+ // contents, not the lists themselves to avoid nesting. This is
15130
+ // necessary because DeclStmts need to be visible after the pre-init.
15131
+ else if (auto *CS = dyn_cast<CompoundStmt>(DependentPreInits))
15132
+ llvm::append_range(OriginalInits.back(), CS->body());
15133
+ else
15134
+ OriginalInits.back().push_back(DependentPreInits);
15102
15135
});
15103
15136
assert(OriginalInits.back().empty() && "No preinit after innermost loop");
15104
15137
OriginalInits.pop_back();
15105
15138
return Result;
15106
15139
}
15107
15140
15141
+ /// Add preinit statements that need to be propageted from the selected loop.
15142
+ static void addLoopPreInits(ASTContext &Context,
15143
+ OMPLoopBasedDirective::HelperExprs &LoopHelper,
15144
+ Stmt *LoopStmt, ArrayRef<Stmt *> OriginalInit,
15145
+ SmallVectorImpl<Stmt *> &PreInits) {
15146
+
15147
+ // For range-based for-statements, ensure that their syntactic sugar is
15148
+ // executed by adding them as pre-init statements.
15149
+ if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt)) {
15150
+ Stmt *RangeInit = CXXRangeFor->getInit();
15151
+ if (RangeInit)
15152
+ PreInits.push_back(RangeInit);
15153
+
15154
+ DeclStmt *RangeStmt = CXXRangeFor->getRangeStmt();
15155
+ PreInits.push_back(new (Context) DeclStmt(RangeStmt->getDeclGroup(),
15156
+ RangeStmt->getBeginLoc(),
15157
+ RangeStmt->getEndLoc()));
15158
+
15159
+ DeclStmt *RangeEnd = CXXRangeFor->getEndStmt();
15160
+ PreInits.push_back(new (Context) DeclStmt(RangeEnd->getDeclGroup(),
15161
+ RangeEnd->getBeginLoc(),
15162
+ RangeEnd->getEndLoc()));
15163
+ }
15164
+
15165
+ llvm::append_range(PreInits, OriginalInit);
15166
+
15167
+ // List of OMPCapturedExprDecl, for __begin, __end, and NumIterations
15168
+ if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits)) {
15169
+ PreInits.push_back(new (Context) DeclStmt(
15170
+ PI->getDeclGroup(), PI->getBeginLoc(), PI->getEndLoc()));
15171
+ }
15172
+
15173
+ // Gather declarations for the data members used as counters.
15174
+ for (Expr *CounterRef : LoopHelper.Counters) {
15175
+ auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
15176
+ if (isa<OMPCapturedExprDecl>(CounterDecl))
15177
+ PreInits.push_back(new (Context) DeclStmt(
15178
+ DeclGroupRef(CounterDecl), SourceLocation(), SourceLocation()));
15179
+ }
15180
+ }
15181
+
15182
+ /// Collect the loop statements (ForStmt or CXXRangeForStmt) of the affected
15183
+ /// loop of a construct.
15184
+ static void collectLoopStmts(Stmt *AStmt, MutableArrayRef<Stmt *> LoopStmts) {
15185
+ size_t NumLoops = LoopStmts.size();
15186
+ OMPLoopBasedDirective::doForAllLoops(
15187
+ AStmt, /*TryImperfectlyNestedLoops=*/false, NumLoops,
15188
+ [LoopStmts](unsigned Cnt, Stmt *CurStmt) {
15189
+ assert(!LoopStmts[Cnt] && "Loop statement must not yet be assigned");
15190
+ LoopStmts[Cnt] = CurStmt;
15191
+ return false;
15192
+ });
15193
+ assert(llvm::all_of(LoopStmts, [](Stmt *LoopStmt) { return LoopStmt; }) &&
15194
+ "Expecting a loop statement for each affected loop");
15195
+ }
15196
+
15108
15197
StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
15109
15198
Stmt *AStmt,
15110
15199
SourceLocation StartLoc,
@@ -15126,8 +15215,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
15126
15215
// Verify and diagnose loop nest.
15127
15216
SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
15128
15217
Stmt *Body = nullptr;
15129
- SmallVector<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>, 4>
15130
- OriginalInits;
15218
+ SmallVector<SmallVector<Stmt *, 0>, 4> OriginalInits;
15131
15219
if (!checkTransformableLoopNest(OMPD_tile, AStmt, NumLoops, LoopHelpers, Body,
15132
15220
OriginalInits))
15133
15221
return StmtError();
@@ -15144,7 +15232,11 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
15144
15232
"Expecting loop iteration space dimensionality to match number of "
15145
15233
"affected loops");
15146
15234
15147
- SmallVector<Decl *, 4> PreInits;
15235
+ // Collect all affected loop statements.
15236
+ SmallVector<Stmt *> LoopStmts(NumLoops, nullptr);
15237
+ collectLoopStmts(AStmt, LoopStmts);
15238
+
15239
+ SmallVector<Stmt *, 4> PreInits;
15148
15240
CaptureVars CopyTransformer(SemaRef);
15149
15241
15150
15242
// Create iteration variables for the generated loops.
@@ -15184,20 +15276,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
15184
15276
&SemaRef.PP.getIdentifierTable().get(TileCntName));
15185
15277
TileIndVars[I] = TileCntDecl;
15186
15278
}
15187
- for (auto &P : OriginalInits[I]) {
15188
- if (auto *D = P.dyn_cast<Decl *>())
15189
- PreInits.push_back(D);
15190
- else if (auto *PI = dyn_cast_or_null<DeclStmt>(P.dyn_cast<Stmt *>()))
15191
- PreInits.append(PI->decl_begin(), PI->decl_end());
15192
- }
15193
- if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits))
15194
- PreInits.append(PI->decl_begin(), PI->decl_end());
15195
- // Gather declarations for the data members used as counters.
15196
- for (Expr *CounterRef : LoopHelper.Counters) {
15197
- auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
15198
- if (isa<OMPCapturedExprDecl>(CounterDecl))
15199
- PreInits.push_back(CounterDecl);
15200
- }
15279
+
15280
+ addLoopPreInits(Context, LoopHelper, LoopStmts[I], OriginalInits[I],
15281
+ PreInits);
15201
15282
}
15202
15283
15203
15284
// Once the original iteration values are set, append the innermost body.
@@ -15246,19 +15327,20 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
15246
15327
OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I];
15247
15328
Expr *NumIterations = LoopHelper.NumIterations;
15248
15329
auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]);
15249
- QualType CntTy = OrigCntVar->getType();
15330
+ QualType IVTy = NumIterations->getType();
15331
+ Stmt *LoopStmt = LoopStmts[I];
15250
15332
15251
15333
// Commonly used variables. One of the constraints of an AST is that every
15252
15334
// node object must appear at most once, hence we define lamdas that create
15253
15335
// a new AST node at every use.
15254
- auto MakeTileIVRef = [&SemaRef = this->SemaRef, &TileIndVars, I, CntTy ,
15336
+ auto MakeTileIVRef = [&SemaRef = this->SemaRef, &TileIndVars, I, IVTy ,
15255
15337
OrigCntVar]() {
15256
- return buildDeclRefExpr(SemaRef, TileIndVars[I], CntTy ,
15338
+ return buildDeclRefExpr(SemaRef, TileIndVars[I], IVTy ,
15257
15339
OrigCntVar->getExprLoc());
15258
15340
};
15259
- auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, CntTy ,
15341
+ auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, IVTy ,
15260
15342
OrigCntVar]() {
15261
- return buildDeclRefExpr(SemaRef, FloorIndVars[I], CntTy ,
15343
+ return buildDeclRefExpr(SemaRef, FloorIndVars[I], IVTy ,
15262
15344
OrigCntVar->getExprLoc());
15263
15345
};
15264
15346
@@ -15320,6 +15402,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
15320
15402
// further into the inner loop.
15321
15403
SmallVector<Stmt *, 4> BodyParts;
15322
15404
BodyParts.append(LoopHelper.Updates.begin(), LoopHelper.Updates.end());
15405
+ if (auto *SourceCXXFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
15406
+ BodyParts.push_back(SourceCXXFor->getLoopVarStmt());
15323
15407
BodyParts.push_back(Inner);
15324
15408
Inner = CompoundStmt::Create(Context, BodyParts, FPOptionsOverride(),
15325
15409
Inner->getBeginLoc(), Inner->getEndLoc());
@@ -15334,12 +15418,14 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
15334
15418
auto &LoopHelper = LoopHelpers[I];
15335
15419
Expr *NumIterations = LoopHelper.NumIterations;
15336
15420
DeclRefExpr *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]);
15337
- QualType CntTy = OrigCntVar ->getType();
15421
+ QualType IVTy = NumIterations ->getType();
15338
15422
15339
- // Commonly used variables.
15340
- auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, CntTy,
15423
+ // Commonly used variables. One of the constraints of an AST is that every
15424
+ // node object must appear at most once, hence we define lamdas that create
15425
+ // a new AST node at every use.
15426
+ auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, IVTy,
15341
15427
OrigCntVar]() {
15342
- return buildDeclRefExpr(SemaRef, FloorIndVars[I], CntTy ,
15428
+ return buildDeclRefExpr(SemaRef, FloorIndVars[I], IVTy ,
15343
15429
OrigCntVar->getExprLoc());
15344
15430
};
15345
15431
@@ -15405,8 +15491,7 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
15405
15491
Stmt *Body = nullptr;
15406
15492
SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers(
15407
15493
NumLoops);
15408
- SmallVector<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>, NumLoops + 1>
15409
- OriginalInits;
15494
+ SmallVector<SmallVector<Stmt *, 0>, NumLoops + 1> OriginalInits;
15410
15495
if (!checkTransformableLoopNest(OMPD_unroll, AStmt, NumLoops, LoopHelpers,
15411
15496
Body, OriginalInits))
15412
15497
return StmtError();
@@ -15418,6 +15503,10 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
15418
15503
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
15419
15504
NumGeneratedLoops, nullptr, nullptr);
15420
15505
15506
+ assert(LoopHelpers.size() == NumLoops &&
15507
+ "Expecting a single-dimensional loop iteration space");
15508
+ assert(OriginalInits.size() == NumLoops &&
15509
+ "Expecting a single-dimensional loop iteration space");
15421
15510
OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();
15422
15511
15423
15512
if (FullClause) {
@@ -15481,24 +15570,13 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
15481
15570
// of a canonical loop nest where these PreInits are emitted before the
15482
15571
// outermost directive.
15483
15572
15573
+ // Find the loop statement.
15574
+ Stmt *LoopStmt = nullptr;
15575
+ collectLoopStmts(AStmt, {LoopStmt});
15576
+
15484
15577
// Determine the PreInit declarations.
15485
- SmallVector<Decl *, 4> PreInits;
15486
- assert(OriginalInits.size() == 1 &&
15487
- "Expecting a single-dimensional loop iteration space");
15488
- for (auto &P : OriginalInits[0]) {
15489
- if (auto *D = P.dyn_cast<Decl *>())
15490
- PreInits.push_back(D);
15491
- else if (auto *PI = dyn_cast_or_null<DeclStmt>(P.dyn_cast<Stmt *>()))
15492
- PreInits.append(PI->decl_begin(), PI->decl_end());
15493
- }
15494
- if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits))
15495
- PreInits.append(PI->decl_begin(), PI->decl_end());
15496
- // Gather declarations for the data members used as counters.
15497
- for (Expr *CounterRef : LoopHelper.Counters) {
15498
- auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
15499
- if (isa<OMPCapturedExprDecl>(CounterDecl))
15500
- PreInits.push_back(CounterDecl);
15501
- }
15578
+ SmallVector<Stmt *, 4> PreInits;
15579
+ addLoopPreInits(Context, LoopHelper, LoopStmt, OriginalInits[0], PreInits);
15502
15580
15503
15581
auto *IterationVarRef = cast<DeclRefExpr>(LoopHelper.IterationVarRef);
15504
15582
QualType IVTy = IterationVarRef->getType();
@@ -15604,6 +15682,8 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
15604
15682
// Inner For statement.
15605
15683
SmallVector<Stmt *> InnerBodyStmts;
15606
15684
InnerBodyStmts.append(LoopHelper.Updates.begin(), LoopHelper.Updates.end());
15685
+ if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
15686
+ InnerBodyStmts.push_back(CXXRangeFor->getLoopVarStmt());
15607
15687
InnerBodyStmts.push_back(Body);
15608
15688
CompoundStmt *InnerBody =
15609
15689
CompoundStmt::Create(getASTContext(), InnerBodyStmts, FPOptionsOverride(),
0 commit comments