21
21
#include "clang/AST/DeclCXX.h"
22
22
#include "clang/AST/DeclOpenMP.h"
23
23
#include "clang/AST/OpenMPClause.h"
24
+ #include "clang/AST/RecursiveASTVisitor.h"
24
25
#include "clang/AST/StmtCXX.h"
25
26
#include "clang/AST/StmtOpenMP.h"
26
27
#include "clang/AST/StmtVisitor.h"
@@ -7668,6 +7669,52 @@ struct LoopIterationSpace final {
7668
7669
Expr *FinalCondition = nullptr;
7669
7670
};
7670
7671
7672
+ /// Scan an AST subtree, checking that no decls in the CollapsedLoopVarDecls
7673
+ /// set are referenced. Used for verifying loop nest structure before
7674
+ /// performing a loop collapse operation.
7675
+ class ForSubExprChecker final : public RecursiveASTVisitor<ForSubExprChecker> {
7676
+ const llvm::SmallPtrSetImpl<const Decl *> &CollapsedLoopVarDecls;
7677
+ VarDecl *ForbiddenVar = nullptr;
7678
+ SourceRange ErrLoc;
7679
+
7680
+ public:
7681
+ explicit ForSubExprChecker(
7682
+ const llvm::SmallPtrSetImpl<const Decl *> &CollapsedLoopVarDecls)
7683
+ : CollapsedLoopVarDecls(CollapsedLoopVarDecls) {}
7684
+
7685
+ // We want to visit implicit code, i.e. synthetic initialisation statements
7686
+ // created during range-for lowering.
7687
+ bool shouldVisitImplicitCode() const { return true; }
7688
+
7689
+ bool VisitDeclRefExpr(DeclRefExpr *E) {
7690
+ ValueDecl *VD = E->getDecl();
7691
+ if (!isa<VarDecl, BindingDecl>(VD))
7692
+ return true;
7693
+ VarDecl *V = VD->getPotentiallyDecomposedVarDecl();
7694
+ if (V->getType()->isReferenceType()) {
7695
+ VarDecl *VD = V->getDefinition();
7696
+ if (VD->hasInit()) {
7697
+ Expr *I = VD->getInit();
7698
+ DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(I);
7699
+ if (!DRE)
7700
+ return true;
7701
+ V = DRE->getDecl()->getPotentiallyDecomposedVarDecl();
7702
+ }
7703
+ }
7704
+ Decl *Canon = V->getCanonicalDecl();
7705
+ if (CollapsedLoopVarDecls.contains(Canon)) {
7706
+ ForbiddenVar = V;
7707
+ ErrLoc = E->getSourceRange();
7708
+ return false;
7709
+ }
7710
+
7711
+ return true;
7712
+ }
7713
+
7714
+ VarDecl *getForbiddenVar() const { return ForbiddenVar; }
7715
+ SourceRange getErrRange() const { return ErrLoc; }
7716
+ };
7717
+
7671
7718
/// Helper class for checking canonical form of the OpenMP loops and
7672
7719
/// extracting iteration space of each loop in the loop nest, that will be used
7673
7720
/// for IR generation.
@@ -7682,6 +7729,8 @@ class OpenMPIterationSpaceChecker {
7682
7729
SourceLocation DefaultLoc;
7683
7730
/// A location for diagnostics (when increment is not compatible).
7684
7731
SourceLocation ConditionLoc;
7732
+ /// The set of variables declared within the (to be collapsed) loop nest.
7733
+ const llvm::SmallPtrSetImpl<const Decl *> &CollapsedLoopVarDecls;
7685
7734
/// A source location for referring to loop init later.
7686
7735
SourceRange InitSrcRange;
7687
7736
/// A source location for referring to condition later.
@@ -7725,10 +7774,13 @@ class OpenMPIterationSpaceChecker {
7725
7774
Expr *Condition = nullptr;
7726
7775
7727
7776
public:
7728
- OpenMPIterationSpaceChecker(Sema &SemaRef, bool SupportsNonRectangular,
7729
- DSAStackTy &Stack, SourceLocation DefaultLoc)
7777
+ OpenMPIterationSpaceChecker(
7778
+ Sema &SemaRef, bool SupportsNonRectangular, DSAStackTy &Stack,
7779
+ SourceLocation DefaultLoc,
7780
+ const llvm::SmallPtrSetImpl<const Decl *> &CollapsedLoopDecls)
7730
7781
: SemaRef(SemaRef), SupportsNonRectangular(SupportsNonRectangular),
7731
- Stack(Stack), DefaultLoc(DefaultLoc), ConditionLoc(DefaultLoc) {}
7782
+ Stack(Stack), DefaultLoc(DefaultLoc), ConditionLoc(DefaultLoc),
7783
+ CollapsedLoopVarDecls(CollapsedLoopDecls) {}
7732
7784
/// Check init-expr for canonical loop form and save loop counter
7733
7785
/// variable - #Var and its initialization value - #LB.
7734
7786
bool checkAndSetInit(Stmt *S, bool EmitDiags = true);
@@ -8049,6 +8101,16 @@ bool OpenMPIterationSpaceChecker::checkAndSetInit(Stmt *S, bool EmitDiags) {
8049
8101
if (!ExprTemp->cleanupsHaveSideEffects())
8050
8102
S = ExprTemp->getSubExpr();
8051
8103
8104
+ if (!CollapsedLoopVarDecls.empty()) {
8105
+ ForSubExprChecker FSEC{CollapsedLoopVarDecls};
8106
+ if (!FSEC.TraverseStmt(S)) {
8107
+ SourceRange Range = FSEC.getErrRange();
8108
+ SemaRef.Diag(Range.getBegin(), diag::err_omp_loop_bad_collapse_var)
8109
+ << Range.getEnd() << 0 << FSEC.getForbiddenVar();
8110
+ return true;
8111
+ }
8112
+ }
8113
+
8052
8114
InitSrcRange = S->getSourceRange();
8053
8115
if (Expr *E = dyn_cast<Expr>(S))
8054
8116
S = E->IgnoreParens();
@@ -8152,6 +8214,17 @@ bool OpenMPIterationSpaceChecker::checkAndSetCond(Expr *S) {
8152
8214
}
8153
8215
Condition = S;
8154
8216
S = getExprAsWritten(S);
8217
+
8218
+ if (!CollapsedLoopVarDecls.empty()) {
8219
+ ForSubExprChecker FSEC{CollapsedLoopVarDecls};
8220
+ if (!FSEC.TraverseStmt(S)) {
8221
+ SourceRange Range = FSEC.getErrRange();
8222
+ SemaRef.Diag(Range.getBegin(), diag::err_omp_loop_bad_collapse_var)
8223
+ << Range.getEnd() << 1 << FSEC.getForbiddenVar();
8224
+ return true;
8225
+ }
8226
+ }
8227
+
8155
8228
SourceLocation CondLoc = S->getBeginLoc();
8156
8229
auto &&CheckAndSetCond =
8157
8230
[this, IneqCondIsCanonical](BinaryOperatorKind Opcode, const Expr *LHS,
@@ -8250,6 +8323,16 @@ bool OpenMPIterationSpaceChecker::checkAndSetInc(Expr *S) {
8250
8323
if (!ExprTemp->cleanupsHaveSideEffects())
8251
8324
S = ExprTemp->getSubExpr();
8252
8325
8326
+ if (!CollapsedLoopVarDecls.empty()) {
8327
+ ForSubExprChecker FSEC{CollapsedLoopVarDecls};
8328
+ if (!FSEC.TraverseStmt(S)) {
8329
+ SourceRange Range = FSEC.getErrRange();
8330
+ SemaRef.Diag(Range.getBegin(), diag::err_omp_loop_bad_collapse_var)
8331
+ << Range.getEnd() << 2 << FSEC.getForbiddenVar();
8332
+ return true;
8333
+ }
8334
+ }
8335
+
8253
8336
IncrementSrcRange = S->getSourceRange();
8254
8337
S = S->IgnoreParens();
8255
8338
if (auto *UO = dyn_cast<UnaryOperator>(S)) {
@@ -8971,8 +9054,9 @@ void SemaOpenMP::ActOnOpenMPLoopInitialization(SourceLocation ForLoc,
8971
9054
return;
8972
9055
8973
9056
DSAStack->loopStart();
9057
+ llvm::SmallPtrSet<const Decl *, 1> EmptyDeclSet{};
8974
9058
OpenMPIterationSpaceChecker ISC(SemaRef, /*SupportsNonRectangular=*/true,
8975
- *DSAStack, ForLoc);
9059
+ *DSAStack, ForLoc, EmptyDeclSet );
8976
9060
if (!ISC.checkAndSetInit(Init, /*EmitDiags=*/false)) {
8977
9061
if (ValueDecl *D = ISC.getLoopDecl()) {
8978
9062
auto *VD = dyn_cast<VarDecl>(D);
@@ -9069,7 +9153,8 @@ static bool checkOpenMPIterationSpace(
9069
9153
Expr *OrderedLoopCountExpr,
9070
9154
SemaOpenMP::VarsWithInheritedDSAType &VarsWithImplicitDSA,
9071
9155
llvm::MutableArrayRef<LoopIterationSpace> ResultIterSpaces,
9072
- llvm::MapVector<const Expr *, DeclRefExpr *> &Captures) {
9156
+ llvm::MapVector<const Expr *, DeclRefExpr *> &Captures,
9157
+ const llvm::SmallPtrSetImpl<const Decl *> &CollapsedLoopVarDecls) {
9073
9158
bool SupportsNonRectangular = !isOpenMPLoopTransformationDirective(DKind);
9074
9159
// OpenMP [2.9.1, Canonical Loop Form]
9075
9160
// for (init-expr; test-expr; incr-expr) structured-block
@@ -9108,7 +9193,8 @@ static bool checkOpenMPIterationSpace(
9108
9193
return false;
9109
9194
9110
9195
OpenMPIterationSpaceChecker ISC(SemaRef, SupportsNonRectangular, DSA,
9111
- For ? For->getForLoc() : CXXFor->getForLoc());
9196
+ For ? For->getForLoc() : CXXFor->getForLoc(),
9197
+ CollapsedLoopVarDecls);
9112
9198
9113
9199
// Check init.
9114
9200
Stmt *Init = For ? For->getInit() : CXXFor->getBeginStmt();
@@ -9475,6 +9561,39 @@ static Expr *buildPostUpdate(Sema &S, ArrayRef<Expr *> PostUpdates) {
9475
9561
return PostUpdate;
9476
9562
}
9477
9563
9564
+ /// Look for variables declared in the body parts of a for-loop nest. Used
9565
+ /// for verifying loop nest structure before performing a loop collapse
9566
+ /// operation.
9567
+ class ForVarDeclFinder final : public RecursiveASTVisitor<ForVarDeclFinder> {
9568
+ int NestingDepth = 0;
9569
+ llvm::SmallPtrSetImpl<const Decl *> &VarDecls;
9570
+
9571
+ public:
9572
+ explicit ForVarDeclFinder(llvm::SmallPtrSetImpl<const Decl *> &VD)
9573
+ : VarDecls(VD) {}
9574
+
9575
+ bool VisitForStmt(ForStmt *F) {
9576
+ ++NestingDepth;
9577
+ TraverseStmt(F->getBody());
9578
+ --NestingDepth;
9579
+ return false;
9580
+ }
9581
+
9582
+ bool VisitCXXForRangeStmt(CXXForRangeStmt *RF) {
9583
+ ++NestingDepth;
9584
+ TraverseStmt(RF->getBody());
9585
+ --NestingDepth;
9586
+ return false;
9587
+ }
9588
+
9589
+ bool VisitVarDecl(VarDecl *D) {
9590
+ Decl *C = D->getCanonicalDecl();
9591
+ if (NestingDepth > 0)
9592
+ VarDecls.insert(C);
9593
+ return true;
9594
+ }
9595
+ };
9596
+
9478
9597
/// Called on a for stmt to check itself and nested loops (if any).
9479
9598
/// \return Returns 0 if one of the collapsed stmts is not canonical for loop,
9480
9599
/// number of collapsed loops otherwise.
@@ -9487,13 +9606,17 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
9487
9606
unsigned NestedLoopCount = 1;
9488
9607
bool SupportsNonPerfectlyNested = (SemaRef.LangOpts.OpenMP >= 50) &&
9489
9608
!isOpenMPLoopTransformationDirective(DKind);
9609
+ llvm::SmallPtrSet<const Decl *, 4> CollapsedLoopVarDecls{};
9490
9610
9491
9611
if (CollapseLoopCountExpr) {
9492
9612
// Found 'collapse' clause - calculate collapse number.
9493
9613
Expr::EvalResult Result;
9494
9614
if (!CollapseLoopCountExpr->isValueDependent() &&
9495
9615
CollapseLoopCountExpr->EvaluateAsInt(Result, SemaRef.getASTContext())) {
9496
9616
NestedLoopCount = Result.Val.getInt().getLimitedValue();
9617
+
9618
+ ForVarDeclFinder FVDF{CollapsedLoopVarDecls};
9619
+ FVDF.TraverseStmt(AStmt);
9497
9620
} else {
9498
9621
Built.clear(/*Size=*/1);
9499
9622
return 1;
@@ -9531,11 +9654,13 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
9531
9654
SupportsNonPerfectlyNested, NumLoops,
9532
9655
[DKind, &SemaRef, &DSA, NumLoops, NestedLoopCount,
9533
9656
CollapseLoopCountExpr, OrderedLoopCountExpr, &VarsWithImplicitDSA,
9534
- &IterSpaces, &Captures](unsigned Cnt, Stmt *CurStmt) {
9657
+ &IterSpaces, &Captures,
9658
+ &CollapsedLoopVarDecls](unsigned Cnt, Stmt *CurStmt) {
9535
9659
if (checkOpenMPIterationSpace(
9536
9660
DKind, CurStmt, SemaRef, DSA, Cnt, NestedLoopCount,
9537
9661
NumLoops, CollapseLoopCountExpr, OrderedLoopCountExpr,
9538
- VarsWithImplicitDSA, IterSpaces, Captures))
9662
+ VarsWithImplicitDSA, IterSpaces, Captures,
9663
+ CollapsedLoopVarDecls))
9539
9664
return true;
9540
9665
if (Cnt > 0 && Cnt >= NestedLoopCount &&
9541
9666
IterSpaces[Cnt].CounterVar) {
0 commit comments