Skip to content

[Clang][OpenMP] Fix tile/unroll on iterator- and foreach-loops. #91459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions clang/include/clang/Sema/SemaOpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -1390,9 +1390,7 @@ class SemaOpenMP : public SemaBase {
bool checkTransformableLoopNest(
OpenMPDirectiveKind Kind, Stmt *AStmt, int NumLoops,
SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
Stmt *&Body,
SmallVectorImpl<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>>
&OriginalInits);
Stmt *&Body, SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits);

/// Helper to keep information about the current `omp begin/end declare
/// variant` nesting.
Expand Down
29 changes: 23 additions & 6 deletions clang/lib/CodeGen/CGStmtOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class OMPTeamsScope final : public OMPLexicalScope {
/// of used expression from loop statement.
class OMPLoopScope : public CodeGenFunction::RunCleanupsScope {
void emitPreInitStmt(CodeGenFunction &CGF, const OMPLoopBasedDirective &S) {
const DeclStmt *PreInits;
const Stmt *PreInits;
CodeGenFunction::OMPMapVars PreCondVars;
if (auto *LD = dyn_cast<OMPLoopDirective>(&S)) {
llvm::DenseSet<const VarDecl *> EmittedAsPrivate;
Expand Down Expand Up @@ -182,17 +182,34 @@ class OMPLoopScope : public CodeGenFunction::RunCleanupsScope {
}
return false;
});
PreInits = cast_or_null<DeclStmt>(LD->getPreInits());
PreInits = LD->getPreInits();
} else if (const auto *Tile = dyn_cast<OMPTileDirective>(&S)) {
PreInits = cast_or_null<DeclStmt>(Tile->getPreInits());
PreInits = Tile->getPreInits();
} else if (const auto *Unroll = dyn_cast<OMPUnrollDirective>(&S)) {
PreInits = cast_or_null<DeclStmt>(Unroll->getPreInits());
PreInits = Unroll->getPreInits();
} else {
llvm_unreachable("Unknown loop-based directive kind.");
}
if (PreInits) {
for (const auto *I : PreInits->decls())
CGF.EmitVarDecl(cast<VarDecl>(*I));
// CompoundStmts and DeclStmts are used as lists of PreInit statements and
// declarations. Since declarations must be visible in the statements that
// follow them, unpack the CompoundStmt they are nested in.
SmallVector<const Stmt *> PreInitStmts;
if (auto *PreInitCompound = dyn_cast<CompoundStmt>(PreInits))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to create a compound stmt? DeclStmt supports multiple declarations itself.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is explained in the summary. In essence, the init-statement for a C++20 foreach-loop does not need to be a DeclStmt, but can be an arbitrary Stmt.

llvm::append_range(PreInitStmts, PreInitCompound->body());
else
PreInitStmts.push_back(PreInits);

for (const Stmt *S : PreInitStmts) {
// EmitStmt skips any OMPCapturedExprDecls, but they need to be emitted
// here.
if (auto *PreInitDecl = dyn_cast<DeclStmt>(S)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, you still emit only DeclStmts? Then why you can't put all Decls into a single one DeclStmt upon creation in Sema?

Copy link
Member Author

@Meinersbur Meinersbur May 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everything else is emitted in CGF.EmitStmt(S); at line 211.

CGF.EmitStmt(S) does itself call CGF.EmitVarDecl(S) if passed a VarDecl, so this special handling should not be necessary. It includes, however, an exception for OMPCapturedExprDecl (subclass of VarDecl) that are NOT emitted so we need to do this explicitly here. Otherwise, lines 203-212 would be just a single CGF.EmitStmt(S).

case Decl::OMPCapturedExpr:

for (Decl *I : PreInitDecl->decls())
CGF.EmitVarDecl(cast<VarDecl>(*I));
continue;
}
CGF.EmitStmt(S);
}
}
PreCondVars.restore(CGF);
}
Expand Down
197 changes: 140 additions & 57 deletions clang/lib/Sema/SemaOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9815,6 +9815,25 @@ static Stmt *buildPreInits(ASTContext &Context,
return nullptr;
}

/// Add \p Item to the end of \p TargetList; if \p Item is a CompoundStmt, add
/// its children instead of the CompoundStmt itself.
///
/// OpenMP pre-init declarations/statements use a CompoundStmt as a container
/// when several statements have to be stored, rather than an explicit list.
/// The container is flattened here because the DeclStmts it holds need to
/// remain visible after the list has been executed.
static void appendFlattendedStmtList(SmallVectorImpl<Stmt *> &TargetList,
                                     Stmt *Item) {
  // A null item stands for an empty list, so there is nothing to add.
  if (!Item)
    return;

  auto *Container = dyn_cast<CompoundStmt>(Item);
  if (!Container) {
    // Any other statement is stored as-is.
    TargetList.push_back(Item);
    return;
  }

  // Splice the children in directly instead of the container.
  llvm::append_range(TargetList, Container->body());
}

/// Build preinits statement for the given declarations.
static Stmt *
buildPreInits(ASTContext &Context,
Expand All @@ -9828,6 +9847,17 @@ buildPreInits(ASTContext &Context,
return nullptr;
}

/// Build pre-init statement for the given statements.
///
/// Every entry of \p PreInits is flattened with appendFlattendedStmtList, so
/// an entry that is itself a CompoundStmt contributes its children rather
/// than a nested container.
///
/// \returns nullptr if there are no statements to store, otherwise a
/// CompoundStmt holding the flattened statements.
static Stmt *buildPreInits(ASTContext &Context, ArrayRef<Stmt *> PreInits) {
  if (PreInits.empty())
    return nullptr;

  SmallVector<Stmt *> Stmts;
  for (Stmt *S : PreInits)
    appendFlattendedStmtList(Stmts, S);

  // Flattening drops null entries and empty containers; nullptr is the
  // representation of an empty pre-init list.
  if (Stmts.empty())
    return nullptr;

  // Use the flattened list, not the raw input: passing PreInits here would
  // keep nested CompoundStmts and discard the flattening done above.
  return CompoundStmt::Create(Context, Stmts, FPOptionsOverride(), {}, {});
}

/// Build postupdate expression for the given list of postupdates expressions.
static Expr *buildPostUpdate(Sema &S, ArrayRef<Expr *> PostUpdates) {
Expr *PostUpdate = nullptr;
Expand Down Expand Up @@ -9924,11 +9954,21 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
Stmt *DependentPreInits = Transform->getPreInits();
if (!DependentPreInits)
return;
for (Decl *C : cast<DeclStmt>(DependentPreInits)->getDeclGroup()) {
auto *D = cast<VarDecl>(C);
DeclRefExpr *Ref = buildDeclRefExpr(SemaRef, D, D->getType(),
Transform->getBeginLoc());
Captures[Ref] = Ref;

// Search for pre-init declared variables that need to be captured
// to be referenceable inside the directive.
SmallVector<Stmt *> Constituents;
appendFlattendedStmtList(Constituents, DependentPreInits);
for (Stmt *S : Constituents) {
if (auto *DC = dyn_cast<DeclStmt>(S)) {
for (Decl *C : DC->decls()) {
auto *D = cast<VarDecl>(C);
DeclRefExpr *Ref = buildDeclRefExpr(
SemaRef, D, D->getType().getNonReferenceType(),
Transform->getBeginLoc());
Captures[Ref] = Ref;
}
}
}
}))
return 0;
Expand Down Expand Up @@ -15059,9 +15099,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective(
bool SemaOpenMP::checkTransformableLoopNest(
OpenMPDirectiveKind Kind, Stmt *AStmt, int NumLoops,
SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
Stmt *&Body,
SmallVectorImpl<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>>
&OriginalInits) {
Stmt *&Body, SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits) {
OriginalInits.emplace_back();
bool Result = OMPLoopBasedDirective::doForAllLoops(
AStmt->IgnoreContainers(), /*TryImperfectlyNestedLoops=*/false, NumLoops,
Expand Down Expand Up @@ -15095,16 +15133,70 @@ bool SemaOpenMP::checkTransformableLoopNest(
DependentPreInits = Dir->getPreInits();
else
llvm_unreachable("Unhandled loop transformation");
if (!DependentPreInits)
return;
llvm::append_range(OriginalInits.back(),
cast<DeclStmt>(DependentPreInits)->getDeclGroup());

appendFlattendedStmtList(OriginalInits.back(), DependentPreInits);
});
assert(OriginalInits.back().empty() && "No preinit after innermost loop");
OriginalInits.pop_back();
return Result;
}

/// Append to \p PreInits the pre-init statements that have to be propagated
/// from the loop \p LoopStmt.
///
/// The statements are appended in this order: the range-based for-loop's own
/// init/range/end statements (if \p LoopStmt is a CXXForRangeStmt), the
/// entries of \p OriginalInit, a copy of \p LoopHelper's PreInits DeclStmt (if
/// any), and one DeclStmt per counter that is an OMPCapturedExprDecl.
static void addLoopPreInits(ASTContext &Context,
                            OMPLoopBasedDirective::HelperExprs &LoopHelper,
                            Stmt *LoopStmt, ArrayRef<Stmt *> OriginalInit,
                            SmallVectorImpl<Stmt *> &PreInits) {
  // Creates a new DeclStmt node with the same declaration group and source
  // range as \p Src.
  auto CloneDeclStmt = [&Context](DeclStmt *Src) -> Stmt * {
    return new (Context)
        DeclStmt(Src->getDeclGroup(), Src->getBeginLoc(), Src->getEndLoc());
  };

  // A range-based for-statement is syntactic sugar; make sure the statements
  // it implies are executed by registering them as pre-init statements.
  if (auto *ForRange = dyn_cast<CXXForRangeStmt>(LoopStmt)) {
    if (Stmt *InitStmt = ForRange->getInit())
      PreInits.push_back(InitStmt);

    PreInits.push_back(CloneDeclStmt(ForRange->getRangeStmt()));
    PreInits.push_back(CloneDeclStmt(ForRange->getEndStmt()));
  }

  PreInits.append(OriginalInit.begin(), OriginalInit.end());

  // List of OMPCapturedExprDecl, for __begin, __end, and NumIterations.
  if (auto *HelperPreInits = cast_or_null<DeclStmt>(LoopHelper.PreInits))
    PreInits.push_back(CloneDeclStmt(HelperPreInits));

  // Gather declarations for the data members used as counters.
  for (Expr *Counter : LoopHelper.Counters) {
    auto *Declared = cast<DeclRefExpr>(Counter)->getDecl();
    if (!isa<OMPCapturedExprDecl>(Declared))
      continue;
    PreInits.push_back(new (Context) DeclStmt(
        DeclGroupRef(Declared), SourceLocation(), SourceLocation()));
  }
}

/// Fill \p LoopStmts with the loop statements (ForStmt or CXXForRangeStmt) of
/// the loops affected by a construct.
///
/// The number of loops visited is the size of \p LoopStmts. Every slot is
/// expected to be null on entry and non-null on return; both conditions are
/// checked by assertions.
static void collectLoopStmts(Stmt *AStmt, MutableArrayRef<Stmt *> LoopStmts) {
  // Stores each visited loop statement in the slot given by its index.
  auto RecordLoop = [LoopStmts](unsigned Cnt, Stmt *CurStmt) {
    assert(!LoopStmts[Cnt] && "Loop statement must not yet be assigned");
    LoopStmts[Cnt] = CurStmt;
    return false;
  };

  OMPLoopBasedDirective::doForAllLoops(AStmt,
                                       /*TryImperfectlyNestedLoops=*/false,
                                       LoopStmts.size(), RecordLoop);

  assert(!is_contained(LoopStmts, nullptr) &&
         "Expecting a loop statement for each affected loop");
}

StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
Stmt *AStmt,
SourceLocation StartLoc,
Expand All @@ -15126,8 +15218,7 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
// Verify and diagnose loop nest.
SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
Stmt *Body = nullptr;
SmallVector<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>, 4>
OriginalInits;
SmallVector<SmallVector<Stmt *, 0>, 4> OriginalInits;
if (!checkTransformableLoopNest(OMPD_tile, AStmt, NumLoops, LoopHelpers, Body,
OriginalInits))
return StmtError();
Expand All @@ -15144,7 +15235,11 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
"Expecting loop iteration space dimensionality to match number of "
"affected loops");

SmallVector<Decl *, 4> PreInits;
// Collect all affected loop statements.
SmallVector<Stmt *> LoopStmts(NumLoops, nullptr);
collectLoopStmts(AStmt, LoopStmts);

SmallVector<Stmt *, 4> PreInits;
CaptureVars CopyTransformer(SemaRef);

// Create iteration variables for the generated loops.
Expand Down Expand Up @@ -15184,20 +15279,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
&SemaRef.PP.getIdentifierTable().get(TileCntName));
TileIndVars[I] = TileCntDecl;
}
for (auto &P : OriginalInits[I]) {
if (auto *D = P.dyn_cast<Decl *>())
PreInits.push_back(D);
else if (auto *PI = dyn_cast_or_null<DeclStmt>(P.dyn_cast<Stmt *>()))
PreInits.append(PI->decl_begin(), PI->decl_end());
}
if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits))
PreInits.append(PI->decl_begin(), PI->decl_end());
// Gather declarations for the data members used as counters.
for (Expr *CounterRef : LoopHelper.Counters) {
auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
if (isa<OMPCapturedExprDecl>(CounterDecl))
PreInits.push_back(CounterDecl);
}

addLoopPreInits(Context, LoopHelper, LoopStmts[I], OriginalInits[I],
PreInits);
}

// Once the original iteration values are set, append the innermost body.
Expand Down Expand Up @@ -15246,19 +15330,20 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I];
Expr *NumIterations = LoopHelper.NumIterations;
auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]);
QualType CntTy = OrigCntVar->getType();
QualType IVTy = NumIterations->getType();
Stmt *LoopStmt = LoopStmts[I];

// Commonly used variables. One of the constraints of an AST is that every
// node object must appear at most once, hence we define lambdas that create
// a new AST node at every use.
auto MakeTileIVRef = [&SemaRef = this->SemaRef, &TileIndVars, I, CntTy,
auto MakeTileIVRef = [&SemaRef = this->SemaRef, &TileIndVars, I, IVTy,
OrigCntVar]() {
return buildDeclRefExpr(SemaRef, TileIndVars[I], CntTy,
return buildDeclRefExpr(SemaRef, TileIndVars[I], IVTy,
OrigCntVar->getExprLoc());
};
auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, CntTy,
auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, IVTy,
OrigCntVar]() {
return buildDeclRefExpr(SemaRef, FloorIndVars[I], CntTy,
return buildDeclRefExpr(SemaRef, FloorIndVars[I], IVTy,
OrigCntVar->getExprLoc());
};

Expand Down Expand Up @@ -15320,6 +15405,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
// further into the inner loop.
SmallVector<Stmt *, 4> BodyParts;
BodyParts.append(LoopHelper.Updates.begin(), LoopHelper.Updates.end());
if (auto *SourceCXXFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
BodyParts.push_back(SourceCXXFor->getLoopVarStmt());
BodyParts.push_back(Inner);
Inner = CompoundStmt::Create(Context, BodyParts, FPOptionsOverride(),
Inner->getBeginLoc(), Inner->getEndLoc());
Expand All @@ -15334,12 +15421,14 @@ StmtResult SemaOpenMP::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses,
auto &LoopHelper = LoopHelpers[I];
Expr *NumIterations = LoopHelper.NumIterations;
DeclRefExpr *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]);
QualType CntTy = OrigCntVar->getType();
QualType IVTy = NumIterations->getType();

// Commonly used variables.
auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, CntTy,
// Commonly used variables. One of the constraints of an AST is that every
// node object must appear at most once, hence we define lambdas that create
// a new AST node at every use.
auto MakeFloorIVRef = [&SemaRef = this->SemaRef, &FloorIndVars, I, IVTy,
OrigCntVar]() {
return buildDeclRefExpr(SemaRef, FloorIndVars[I], CntTy,
return buildDeclRefExpr(SemaRef, FloorIndVars[I], IVTy,
OrigCntVar->getExprLoc());
};

Expand Down Expand Up @@ -15405,8 +15494,7 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
Stmt *Body = nullptr;
SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers(
NumLoops);
SmallVector<SmallVector<llvm::PointerUnion<Stmt *, Decl *>, 0>, NumLoops + 1>
OriginalInits;
SmallVector<SmallVector<Stmt *, 0>, NumLoops + 1> OriginalInits;
if (!checkTransformableLoopNest(OMPD_unroll, AStmt, NumLoops, LoopHelpers,
Body, OriginalInits))
return StmtError();
Expand All @@ -15418,6 +15506,10 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
NumGeneratedLoops, nullptr, nullptr);

assert(LoopHelpers.size() == NumLoops &&
"Expecting a single-dimensional loop iteration space");
assert(OriginalInits.size() == NumLoops &&
"Expecting a single-dimensional loop iteration space");
OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();

if (FullClause) {
Expand Down Expand Up @@ -15481,24 +15573,13 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
// of a canonical loop nest where these PreInits are emitted before the
// outermost directive.

// Find the loop statement.
Stmt *LoopStmt = nullptr;
collectLoopStmts(AStmt, {LoopStmt});

// Determine the PreInit declarations.
SmallVector<Decl *, 4> PreInits;
assert(OriginalInits.size() == 1 &&
"Expecting a single-dimensional loop iteration space");
for (auto &P : OriginalInits[0]) {
if (auto *D = P.dyn_cast<Decl *>())
PreInits.push_back(D);
else if (auto *PI = dyn_cast_or_null<DeclStmt>(P.dyn_cast<Stmt *>()))
PreInits.append(PI->decl_begin(), PI->decl_end());
}
if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits))
PreInits.append(PI->decl_begin(), PI->decl_end());
// Gather declarations for the data members used as counters.
for (Expr *CounterRef : LoopHelper.Counters) {
auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl();
if (isa<OMPCapturedExprDecl>(CounterDecl))
PreInits.push_back(CounterDecl);
}
SmallVector<Stmt *, 4> PreInits;
addLoopPreInits(Context, LoopHelper, LoopStmt, OriginalInits[0], PreInits);

auto *IterationVarRef = cast<DeclRefExpr>(LoopHelper.IterationVarRef);
QualType IVTy = IterationVarRef->getType();
Expand Down Expand Up @@ -15604,6 +15685,8 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
// Inner For statement.
SmallVector<Stmt *> InnerBodyStmts;
InnerBodyStmts.append(LoopHelper.Updates.begin(), LoopHelper.Updates.end());
if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
InnerBodyStmts.push_back(CXXRangeFor->getLoopVarStmt());
InnerBodyStmts.push_back(Body);
CompoundStmt *InnerBody =
CompoundStmt::Create(getASTContext(), InnerBodyStmts, FPOptionsOverride(),
Expand Down
Loading
Loading