Skip to content

Commit d6057c4

Browse files
committed
[Clang][OpenMP] Fix tile/unroll on iterator- and foreach-loops
1 parent 7efafb0 commit d6057c4

18 files changed

+2506
-445
lines changed

clang/include/clang/Sema/SemaOpenMP.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,9 +1390,7 @@ class SemaOpenMP : public SemaBase {
13901390
bool checkTransformableLoopNest(
13911391
OpenMPDirectiveKind Kind, Stmt *AStmt, int NumLoops,
13921392
SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
1393-
Stmt *&Body,
1394-
SmallVectorImpl<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>>
1395-
&OriginalInits);
1393+
Stmt *&Body, SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits);
13961394

13971395
/// Helper to keep information about the current `omp begin/end declare
13981396
/// variant` nesting.

clang/lib/CodeGen/CGStmtOpenMP.cpp

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class OMPTeamsScope final : public OMPLexicalScope {
142142
/// of used expression from loop statement.
143143
class OMPLoopScope : public CodeGenFunction::RunCleanupsScope {
144144
void emitPreInitStmt(CodeGenFunction &CGF, const OMPLoopBasedDirective &S) {
145-
const DeclStmt *PreInits;
145+
const Stmt *PreInits;
146146
CodeGenFunction::OMPMapVars PreCondVars;
147147
if (auto *LD = dyn_cast<OMPLoopDirective>(&S)) {
148148
llvm::DenseSet<const VarDecl *> EmittedAsPrivate;
@@ -182,17 +182,34 @@ class OMPLoopScope : public CodeGenFunction::RunCleanupsScope {
182182
}
183183
return false;
184184
});
185-
PreInits = cast_or_null<DeclStmt>(LD->getPreInits());
185+
PreInits = LD->getPreInits();
186186
} else if (const auto *Tile = dyn_cast<OMPTileDirective>(&S)) {
187-
PreInits = cast_or_null<DeclStmt>(Tile->getPreInits());
187+
PreInits = Tile->getPreInits();
188188
} else if (const auto *Unroll = dyn_cast<OMPUnrollDirective>(&S)) {
189-
PreInits = cast_or_null<DeclStmt>(Unroll->getPreInits());
189+
PreInits = Unroll->getPreInits();
190190
} else {
191191
llvm_unreachable("Unknown loop-based directive kind.");
192192
}
193193
if (PreInits) {
194-
for (const auto *I : PreInits->decls())
195-
CGF.EmitVarDecl(cast<VarDecl>(*I));
194+
// CompoundStmts and DeclStmts are used as lists of PreInit statements and
195+
// declarations. Since declarations must be visible in the the following
196+
// that they initialize, unpack the ComboundStmt they are nested in.
197+
SmallVector<const Stmt *> PreInitStmts;
198+
if (auto *PreInitCompound = dyn_cast<CompoundStmt>(PreInits))
199+
llvm::append_range(PreInitStmts, PreInitCompound->body());
200+
else
201+
PreInitStmts.push_back(PreInits);
202+
203+
for (const Stmt *S : PreInitStmts) {
204+
// EmitStmt skips any OMPCapturedExprDecls, but needs to be emitted
205+
// here.
206+
if (auto *PreInitDecl = dyn_cast<DeclStmt>(S)) {
207+
for (Decl *I : PreInitDecl->decls())
208+
CGF.EmitVarDecl(cast<VarDecl>(*I));
209+
continue;
210+
}
211+
CGF.EmitStmt(S);
212+
}
196213
}
197214
PreCondVars.restore(CGF);
198215
}

clang/lib/Sema/SemaOpenMP.cpp

Lines changed: 135 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9828,6 +9828,23 @@ buildPreInits(ASTContext &Context,
98289828
return nullptr;
98299829
}
98309830

9831+
/// Build pre-init statement for the given statements.
9832+
static Stmt *buildPreInits(ASTContext &Context, ArrayRef<Stmt *> PreInits) {
9833+
if (!PreInits.empty()) {
9834+
SmallVector<Stmt *> Stmts;
9835+
for (Stmt *S : PreInits) {
9836+
// Do not nest CompoundStmts.
9837+
if (auto *CS = dyn_cast<CompoundStmt>(S)) {
9838+
llvm::append_range(Stmts, CS->body());
9839+
continue;
9840+
}
9841+
Stmts.push_back(S);
9842+
}
9843+
return CompoundStmt::Create(Context, PreInits, FPOptionsOverride(), {}, {});
9844+
}
9845+
return nullptr;
9846+
}
9847+
98319848
/// Build postupdate expression for the given list of postupdates expressions.
98329849
static Expr *buildPostUpdate(Sema &S, ArrayRef<Expr *> PostUpdates) {
98339850
Expr *PostUpdate = nullptr;
@@ -9924,11 +9941,24 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
99249941
Stmt *DependentPreInits = Transform->getPreInits();
99259942
if (!DependentPreInits)
99269943
return;
9927-
for (Decl *C : cast<DeclStmt>(DependentPreInits)->getDeclGroup()) {
9928-
auto *D = cast<VarDecl>(C);
9929-
DeclRefExpr *Ref = buildDeclRefExpr(SemaRef, D, D->getType(),
9930-
Transform->getBeginLoc());
9931-
Captures[Ref] = Ref;
9944+
9945+
// Search for pre-init declared variables that need to be captured
9946+
// to be referenceable inside the directive.
9947+
SmallVector<Stmt *> Constituents;
9948+
if (auto *CS = dyn_cast<CompoundStmt>(DependentPreInits))
9949+
llvm::append_range(Constituents, CS->body());
9950+
else
9951+
Constituents.push_back(DependentPreInits);
9952+
for (Stmt *S : Constituents) {
9953+
if (DeclStmt *DC = dyn_cast<DeclStmt>(S)) {
9954+
for (Decl *C : DC->decls()) {
9955+
auto *D = cast<VarDecl>(C);
9956+
DeclRefExpr *Ref = buildDeclRefExpr(
9957+
SemaRef, D, D->getType().getNonReferenceType(),
9958+
Transform->getBeginLoc());
9959+
Captures[Ref] = Ref;
9960+
}
9961+
}
99329962
}
99339963
}))
99349964
return 0;
@@ -15059,9 +15089,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective(
1505915089
bool SemaOpenMP::checkTransformableLoopNest(
1506015090
OpenMPDirectiveKind Kind, Stmt *AStmt, int NumLoops,
1506115091
SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
15062-
Stmt *&Body,
15063-
SmallVectorImpl<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>>
15064-
&OriginalInits) {
15092+
Stmt *&Body, SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits) {
1506515093
OriginalInits.emplace_back();
1506615094
bool Result = OMPLoopBasedDirective::doForAllLoops(
1506715095
AStmt->IgnoreContainers(), /*TryImperfectlyNestedLoops=*/false, NumLoops,
@@ -15097,14 +15125,75 @@ bool SemaOpenMP::checkTransformableLoopNest(
1509715125
llvm_unreachable("Unhandled loop transformation");
1509815126
if (!DependentPreInits)
1509915127
return;
15100-
llvm::append_range(OriginalInits.back(),
15101-
cast<DeclStmt>(DependentPreInits)->getDeclGroup());
15128+
// CompoundStmts are used as lists of other statements, add their
15129+
// contents, not the lists themselves to avoid nesting. This is
15130+
// necessary because DeclStmts need to be visible after the pre-init.
15131+
else if (auto *CS = dyn_cast<CompoundStmt>(DependentPreInits))
15132+
llvm::append_range(OriginalInits.back(), CS->body());
15133+
else
15134+
OriginalInits.back().push_back(DependentPreInits);
1510215135
});
1510315136
assert(OriginalInits.back().empty() && "No preinit after innermost loop");
1510415137
OriginalInits.pop_back();
1510515138
return Result;
1510615139
}
1510715140

15141+
/// Add preinit statements that need to be propageted from the selected loop.
15142+
static void addLoopPreInits(ASTContext &Context,
15143+
OMPLoopBasedDirective::HelperExprs &LoopHelper,
15144+
Stmt *LoopStmt, ArrayRef<Stmt *> OriginalInit,
15145+
SmallVectorImpl<Stmt *> &PreInits) {
15146+
15147+
// For range-based for-statements, ensure that their syntactic sugar is
15148+
// executed by adding them as pre-init statements.
15149+
if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt)) {
15150+
Stmt *RangeInit = CXXRangeFor->getInit();
15151+
if (RangeInit)
15152+
PreInits.push_back(RangeInit);
15153+
15154+
DeclStmt *RangeStmt = CXXRangeFor->getRangeStmt();
15155+
PreInits.push_back(new (Context) DeclStmt(RangeStmt->getDeclGroup(),
15156+
RangeStmt->getBeginLoc(),
15157+
RangeStmt->getEndLoc()));
15158+
15159+
DeclStmt *RangeEnd = CXXRangeFor->getEndStmt();
15160+
PreInits.push_back(new (Context) DeclStmt(RangeEnd->getDeclGroup(),
15161+
RangeEnd->getBeginLoc(),
15162+
RangeEnd->getEndLoc()));
15163+
}
15164+
15165+
llvm::append_range(PreInits, OriginalInit);
15166+
15167+
// List of OMPCapturedExprDecl, for __begin, __end, and NumIterations
15168+
if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits)) {
15169+
PreInits.push_back(new (Context) DeclStmt(
15170+
PI->getDeclGroup(), PI->getBeginLoc(), PI->getEndLoc()));
15171+
}
15172+
15173+
// Gather declarations for the data members used as counters.
15174+
for (Expr *CounterRef : LoopHelper.Counters) {
15175+
auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
15176+
if (isa<OMPCapturedExprDecl>(CounterDecl))
15177+
PreInits.push_back(new (Context) DeclStmt(
15178+
DeclGroupRef(CounterDecl), SourceLocation(), SourceLocation()));
15179+
}
15180+
}
15181+
15182+
/// Collect the loop statements (ForStmt or CXXRangeForStmt) of the affected
15183+
/// loop of a construct.
15184+
static void collectLoopStmts(Stmt *AStmt, MutableArrayRef<Stmt *> LoopStmts) {
15185+
size_t NumLoops = LoopStmts.size();
15186+
OMPLoopBasedDirective::doForAllLoops(
15187+
AStmt, /*TryImperfectlyNestedLoops=*/false, NumLoops,
15188+
[LoopStmts](unsigned Cnt, Stmt *CurStmt) {
15189+
assert(!LoopStmts[Cnt] && "Loop statement must not yet be assigned");
15190+
LoopStmts[Cnt] = CurStmt;
15191+
return false;
15192+
});
15193+
assert(llvm::all_of(LoopStmts, [](Stmt *LoopStmt) { return LoopStmt; }) &&
15194+
"Expecting a loop statement for each affected loop");
15195+
}
15196+
1510815197
StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1510915198
Stmt *AStmt,
1511015199
SourceLocation StartLoc,
@@ -15126,8 +15215,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1512615215
// Verify and diagnose loop nest.
1512715216
SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
1512815217
Stmt *Body = nullptr;
15129-
SmallVector<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>, 4>
15130-
OriginalInits;
15218+
SmallVector<SmallVector<Stmt *, 0>, 4> OriginalInits;
1513115219
if (!checkTransformableLoopNest(OMPD_tile, AStmt, NumLoops, LoopHelpers, Body,
1513215220
OriginalInits))
1513315221
return StmtError();
@@ -15144,7 +15232,11 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1514415232
"Expecting loop iteration space dimensionality to match number of "
1514515233
"affected loops");
1514615234

15147-
SmallVector<Decl *, 4> PreInits;
15235+
// Collect all affected loop statements.
15236+
SmallVector<Stmt *> LoopStmts(NumLoops, nullptr);
15237+
collectLoopStmts(AStmt, LoopStmts);
15238+
15239+
SmallVector<Stmt *, 4> PreInits;
1514815240
CaptureVars CopyTransformer(SemaRef);
1514915241

1515015242
// Create iteration variables for the generated loops.
@@ -15184,20 +15276,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1518415276
&SemaRef.PP.getIdentifierTable().get(TileCntName));
1518515277
TileIndVars[I] = TileCntDecl;
1518615278
}
15187-
for (auto &P : OriginalInits[I]) {
15188-
if (auto *D = P.dyn_cast<Decl *>())
15189-
PreInits.push_back(D);
15190-
else if (auto *PI = dyn_cast_or_null<DeclStmt>(P.dyn_cast<Stmt *>()))
15191-
PreInits.append(PI->decl_begin(), PI->decl_end());
15192-
}
15193-
if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits))
15194-
PreInits.append(PI->decl_begin(), PI->decl_end());
15195-
// Gather declarations for the data members used as counters.
15196-
for (Expr *CounterRef : LoopHelper.Counters) {
15197-
auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
15198-
if (isa<OMPCapturedExprDecl>(CounterDecl))
15199-
PreInits.push_back(CounterDecl);
15200-
}
15279+
15280+
addLoopPreInits(Context, LoopHelper, LoopStmts[I], OriginalInits[I],
15281+
PreInits);
1520115282
}
1520215283

1520315284
// Once the original iteration values are set, append the innermost body.
@@ -15246,19 +15327,20 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1524615327
OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I];
1524715328
Expr *NumIterations = LoopHelper.NumIterations;
1524815329
auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]);
15249-
QualType CntTy = OrigCntVar->getType();
15330+
QualType IVTy = NumIterations->getType();
15331+
Stmt *LoopStmt = LoopStmts[I];
1525015332

1525115333
// Commonly used variables. One of the constraints of an AST is that every
1525215334
// node object must appear at most once, hence we define lamdas that create
1525315335
// a new AST node at every use.
15254-
auto MakeTileIVRef = [&SemaRef = this->SemaRef, &TileIndVars, I, CntTy,
15336+
auto MakeTileIVRef = [&SemaRef = this->SemaRef, &TileIndVars, I, IVTy,
1525515337
OrigCntVar]() {
15256-
return buildDeclRefExpr(SemaRef, TileIndVars[I], CntTy,
15338+
return buildDeclRefExpr(SemaRef, TileIndVars[I], IVTy,
1525715339
OrigCntVar->getExprLoc());
1525815340
};
15259-
auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, CntTy,
15341+
auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, IVTy,
1526015342
OrigCntVar]() {
15261-
return buildDeclRefExpr(SemaRef, FloorIndVars[I], CntTy,
15343+
return buildDeclRefExpr(SemaRef, FloorIndVars[I], IVTy,
1526215344
OrigCntVar->getExprLoc());
1526315345
};
1526415346

@@ -15320,6 +15402,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1532015402
// further into the inner loop.
1532115403
SmallVector<Stmt *, 4> BodyParts;
1532215404
BodyParts.append(LoopHelper.Updates.begin(), LoopHelper.Updates.end());
15405+
if (auto *SourceCXXFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
15406+
BodyParts.push_back(SourceCXXFor->getLoopVarStmt());
1532315407
BodyParts.push_back(Inner);
1532415408
Inner = CompoundStmt::Create(Context, BodyParts, FPOptionsOverride(),
1532515409
Inner->getBeginLoc(), Inner->getEndLoc());
@@ -15334,12 +15418,14 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
1533415418
auto &LoopHelper = LoopHelpers[I];
1533515419
Expr *NumIterations = LoopHelper.NumIterations;
1533615420
DeclRefExpr *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]);
15337-
QualType CntTy = OrigCntVar->getType();
15421+
QualType IVTy = NumIterations->getType();
1533815422

15339-
// Commonly used variables.
15340-
auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, CntTy,
15423+
// Commonly used variables. One of the constraints of an AST is that every
15424+
// node object must appear at most once, hence we define lamdas that create
15425+
// a new AST node at every use.
15426+
auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, IVTy,
1534115427
OrigCntVar]() {
15342-
return buildDeclRefExpr(SemaRef, FloorIndVars[I], CntTy,
15428+
return buildDeclRefExpr(SemaRef, FloorIndVars[I], IVTy,
1534315429
OrigCntVar->getExprLoc());
1534415430
};
1534515431

@@ -15405,8 +15491,7 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
1540515491
Stmt *Body = nullptr;
1540615492
SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers(
1540715493
NumLoops);
15408-
SmallVector<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>, NumLoops + 1>
15409-
OriginalInits;
15494+
SmallVector<SmallVector<Stmt *, 0>, NumLoops + 1> OriginalInits;
1541015495
if (!checkTransformableLoopNest(OMPD_unroll, AStmt, NumLoops, LoopHelpers,
1541115496
Body, OriginalInits))
1541215497
return StmtError();
@@ -15418,6 +15503,10 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
1541815503
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
1541915504
NumGeneratedLoops, nullptr, nullptr);
1542015505

15506+
assert(LoopHelpers.size() == NumLoops &&
15507+
"Expecting a single-dimensional loop iteration space");
15508+
assert(OriginalInits.size() == NumLoops &&
15509+
"Expecting a single-dimensional loop iteration space");
1542115510
OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();
1542215511

1542315512
if (FullClause) {
@@ -15481,24 +15570,13 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
1548115570
// of a canonical loop nest where these PreInits are emitted before the
1548215571
// outermost directive.
1548315572

15573+
// Find the loop statement.
15574+
Stmt *LoopStmt = nullptr;
15575+
collectLoopStmts(AStmt, {LoopStmt});
15576+
1548415577
// Determine the PreInit declarations.
15485-
SmallVector<Decl *, 4> PreInits;
15486-
assert(OriginalInits.size() == 1 &&
15487-
"Expecting a single-dimensional loop iteration space");
15488-
for (auto &P : OriginalInits[0]) {
15489-
if (auto *D = P.dyn_cast<Decl *>())
15490-
PreInits.push_back(D);
15491-
else if (auto *PI = dyn_cast_or_null<DeclStmt>(P.dyn_cast<Stmt *>()))
15492-
PreInits.append(PI->decl_begin(), PI->decl_end());
15493-
}
15494-
if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits))
15495-
PreInits.append(PI->decl_begin(), PI->decl_end());
15496-
// Gather declarations for the data members used as counters.
15497-
for (Expr *CounterRef : LoopHelper.Counters) {
15498-
auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
15499-
if (isa<OMPCapturedExprDecl>(CounterDecl))
15500-
PreInits.push_back(CounterDecl);
15501-
}
15578+
SmallVector<Stmt *, 4> PreInits;
15579+
addLoopPreInits(Context, LoopHelper, LoopStmt, OriginalInits[0], PreInits);
1550215580

1550315581
auto *IterationVarRef = cast<DeclRefExpr>(LoopHelper.IterationVarRef);
1550415582
QualType IVTy = IterationVarRef->getType();
@@ -15604,6 +15682,8 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
1560415682
// Inner For statement.
1560515683
SmallVector<Stmt *> InnerBodyStmts;
1560615684
InnerBodyStmts.append(LoopHelper.Updates.begin(), LoopHelper.Updates.end());
15685+
if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
15686+
InnerBodyStmts.push_back(CXXRangeFor->getLoopVarStmt());
1560715687
InnerBodyStmts.push_back(Body);
1560815688
CompoundStmt *InnerBody =
1560915689
CompoundStmt::Create(getASTContext(), InnerBodyStmts, FPOptionsOverride(),

0 commit comments

Comments
 (0)