21
21
#include "clang/AST/DeclCXX.h"
22
22
#include "clang/AST/DeclOpenMP.h"
23
23
#include "clang/AST/OpenMPClause.h"
24
+ #include "clang/AST/RecursiveASTVisitor.h"
24
25
#include "clang/AST/StmtCXX.h"
25
26
#include "clang/AST/StmtOpenMP.h"
26
27
#include "clang/AST/StmtVisitor.h"
@@ -7695,6 +7696,52 @@ struct LoopIterationSpace final {
7695
7696
Expr *FinalCondition = nullptr;
7696
7697
};
7697
7698
7699
+ /// Scan an AST subtree, checking that no decls in the CollapsedLoopVarDecls
7700
+ /// set are referenced. Used for verifying loop nest structure before
7701
+ /// performing a loop collapse operation.
7702
+ class ForSubExprChecker final : public RecursiveASTVisitor<ForSubExprChecker> {
7703
+ const llvm::SmallPtrSetImpl<const Decl *> &CollapsedLoopVarDecls;
7704
+ VarDecl *ForbiddenVar = nullptr;
7705
+ SourceRange ErrLoc;
7706
+
7707
+ public:
7708
+ explicit ForSubExprChecker(
7709
+ const llvm::SmallPtrSetImpl<const Decl *> &CollapsedLoopVarDecls)
7710
+ : CollapsedLoopVarDecls(CollapsedLoopVarDecls) {}
7711
+
7712
+ // We want to visit implicit code, i.e. synthetic initialisation statements
7713
+ // created during range-for lowering.
7714
+ bool shouldVisitImplicitCode() const { return true; }
7715
+
7716
+ bool VisitDeclRefExpr(DeclRefExpr *E) {
7717
+ ValueDecl *VD = E->getDecl();
7718
+ if (!isa<VarDecl, BindingDecl>(VD))
7719
+ return true;
7720
+ VarDecl *V = VD->getPotentiallyDecomposedVarDecl();
7721
+ if (V->getType()->isReferenceType()) {
7722
+ VarDecl *VD = V->getDefinition();
7723
+ if (VD->hasInit()) {
7724
+ Expr *I = VD->getInit();
7725
+ DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(I);
7726
+ if (!DRE)
7727
+ return true;
7728
+ V = DRE->getDecl()->getPotentiallyDecomposedVarDecl();
7729
+ }
7730
+ }
7731
+ Decl *Canon = V->getCanonicalDecl();
7732
+ if (CollapsedLoopVarDecls.contains(Canon)) {
7733
+ ForbiddenVar = V;
7734
+ ErrLoc = E->getSourceRange();
7735
+ return false;
7736
+ }
7737
+
7738
+ return true;
7739
+ }
7740
+
7741
+ VarDecl *getForbiddenVar() const { return ForbiddenVar; }
7742
+ SourceRange getErrRange() const { return ErrLoc; }
7743
+ };
7744
+
7698
7745
/// Helper class for checking canonical form of the OpenMP loops and
7699
7746
/// extracting iteration space of each loop in the loop nest, that will be used
7700
7747
/// for IR generation.
@@ -7709,6 +7756,8 @@ class OpenMPIterationSpaceChecker {
7709
7756
SourceLocation DefaultLoc;
7710
7757
/// A location for diagnostics (when increment is not compatible).
7711
7758
SourceLocation ConditionLoc;
7759
+ /// The set of variables declared within the (to be collapsed) loop nest.
7760
+ const llvm::SmallPtrSetImpl<const Decl *> &CollapsedLoopVarDecls;
7712
7761
/// A source location for referring to loop init later.
7713
7762
SourceRange InitSrcRange;
7714
7763
/// A source location for referring to condition later.
@@ -7752,10 +7801,13 @@ class OpenMPIterationSpaceChecker {
7752
7801
Expr *Condition = nullptr;
7753
7802
7754
7803
public:
7755
- OpenMPIterationSpaceChecker(Sema &SemaRef, bool SupportsNonRectangular,
7756
- DSAStackTy &Stack, SourceLocation DefaultLoc)
7804
+ OpenMPIterationSpaceChecker(
7805
+ Sema &SemaRef, bool SupportsNonRectangular, DSAStackTy &Stack,
7806
+ SourceLocation DefaultLoc,
7807
+ const llvm::SmallPtrSetImpl<const Decl *> &CollapsedLoopDecls)
7757
7808
: SemaRef(SemaRef), SupportsNonRectangular(SupportsNonRectangular),
7758
- Stack(Stack), DefaultLoc(DefaultLoc), ConditionLoc(DefaultLoc) {}
7809
+ Stack(Stack), DefaultLoc(DefaultLoc), ConditionLoc(DefaultLoc),
7810
+ CollapsedLoopVarDecls(CollapsedLoopDecls) {}
7759
7811
/// Check init-expr for canonical loop form and save loop counter
7760
7812
/// variable - #Var and its initialization value - #LB.
7761
7813
bool checkAndSetInit(Stmt *S, bool EmitDiags = true);
@@ -8076,6 +8128,16 @@ bool OpenMPIterationSpaceChecker::checkAndSetInit(Stmt *S, bool EmitDiags) {
8076
8128
if (!ExprTemp->cleanupsHaveSideEffects())
8077
8129
S = ExprTemp->getSubExpr();
8078
8130
8131
+ if (!CollapsedLoopVarDecls.empty()) {
8132
+ ForSubExprChecker FSEC{CollapsedLoopVarDecls};
8133
+ if (!FSEC.TraverseStmt(S)) {
8134
+ SourceRange Range = FSEC.getErrRange();
8135
+ SemaRef.Diag(Range.getBegin(), diag::err_omp_loop_bad_collapse_var)
8136
+ << Range.getEnd() << 0 << FSEC.getForbiddenVar();
8137
+ return true;
8138
+ }
8139
+ }
8140
+
8079
8141
InitSrcRange = S->getSourceRange();
8080
8142
if (Expr *E = dyn_cast<Expr>(S))
8081
8143
S = E->IgnoreParens();
@@ -8179,6 +8241,17 @@ bool OpenMPIterationSpaceChecker::checkAndSetCond(Expr *S) {
8179
8241
}
8180
8242
Condition = S;
8181
8243
S = getExprAsWritten(S);
8244
+
8245
+ if (!CollapsedLoopVarDecls.empty()) {
8246
+ ForSubExprChecker FSEC{CollapsedLoopVarDecls};
8247
+ if (!FSEC.TraverseStmt(S)) {
8248
+ SourceRange Range = FSEC.getErrRange();
8249
+ SemaRef.Diag(Range.getBegin(), diag::err_omp_loop_bad_collapse_var)
8250
+ << Range.getEnd() << 1 << FSEC.getForbiddenVar();
8251
+ return true;
8252
+ }
8253
+ }
8254
+
8182
8255
SourceLocation CondLoc = S->getBeginLoc();
8183
8256
auto &&CheckAndSetCond =
8184
8257
[this, IneqCondIsCanonical](BinaryOperatorKind Opcode, const Expr *LHS,
@@ -8277,6 +8350,16 @@ bool OpenMPIterationSpaceChecker::checkAndSetInc(Expr *S) {
8277
8350
if (!ExprTemp->cleanupsHaveSideEffects())
8278
8351
S = ExprTemp->getSubExpr();
8279
8352
8353
+ if (!CollapsedLoopVarDecls.empty()) {
8354
+ ForSubExprChecker FSEC{CollapsedLoopVarDecls};
8355
+ if (!FSEC.TraverseStmt(S)) {
8356
+ SourceRange Range = FSEC.getErrRange();
8357
+ SemaRef.Diag(Range.getBegin(), diag::err_omp_loop_bad_collapse_var)
8358
+ << Range.getEnd() << 2 << FSEC.getForbiddenVar();
8359
+ return true;
8360
+ }
8361
+ }
8362
+
8280
8363
IncrementSrcRange = S->getSourceRange();
8281
8364
S = S->IgnoreParens();
8282
8365
if (auto *UO = dyn_cast<UnaryOperator>(S)) {
@@ -8998,8 +9081,9 @@ void SemaOpenMP::ActOnOpenMPLoopInitialization(SourceLocation ForLoc,
8998
9081
return;
8999
9082
9000
9083
DSAStack->loopStart();
9084
+ llvm::SmallPtrSet<const Decl *, 1> EmptyDeclSet;
9001
9085
OpenMPIterationSpaceChecker ISC(SemaRef, /*SupportsNonRectangular=*/true,
9002
- *DSAStack, ForLoc);
9086
+ *DSAStack, ForLoc, EmptyDeclSet );
9003
9087
if (!ISC.checkAndSetInit(Init, /*EmitDiags=*/false)) {
9004
9088
if (ValueDecl *D = ISC.getLoopDecl()) {
9005
9089
auto *VD = dyn_cast<VarDecl>(D);
@@ -9096,7 +9180,8 @@ static bool checkOpenMPIterationSpace(
9096
9180
Expr *OrderedLoopCountExpr,
9097
9181
SemaOpenMP::VarsWithInheritedDSAType &VarsWithImplicitDSA,
9098
9182
llvm::MutableArrayRef<LoopIterationSpace> ResultIterSpaces,
9099
- llvm::MapVector<const Expr *, DeclRefExpr *> &Captures) {
9183
+ llvm::MapVector<const Expr *, DeclRefExpr *> &Captures,
9184
+ const llvm::SmallPtrSetImpl<const Decl *> &CollapsedLoopVarDecls) {
9100
9185
bool SupportsNonRectangular = !isOpenMPLoopTransformationDirective(DKind);
9101
9186
// OpenMP [2.9.1, Canonical Loop Form]
9102
9187
// for (init-expr; test-expr; incr-expr) structured-block
@@ -9135,7 +9220,8 @@ static bool checkOpenMPIterationSpace(
9135
9220
return false;
9136
9221
9137
9222
OpenMPIterationSpaceChecker ISC(SemaRef, SupportsNonRectangular, DSA,
9138
- For ? For->getForLoc() : CXXFor->getForLoc());
9223
+ For ? For->getForLoc() : CXXFor->getForLoc(),
9224
+ CollapsedLoopVarDecls);
9139
9225
9140
9226
// Check init.
9141
9227
Stmt *Init = For ? For->getInit() : CXXFor->getBeginStmt();
@@ -9502,6 +9588,39 @@ static Expr *buildPostUpdate(Sema &S, ArrayRef<Expr *> PostUpdates) {
9502
9588
return PostUpdate;
9503
9589
}
9504
9590
9591
+ /// Look for variables declared in the body parts of a for-loop nest. Used
9592
+ /// for verifying loop nest structure before performing a loop collapse
9593
+ /// operation.
9594
+ class ForVarDeclFinder final : public RecursiveASTVisitor<ForVarDeclFinder> {
9595
+ int NestingDepth = 0;
9596
+ llvm::SmallPtrSetImpl<const Decl *> &VarDecls;
9597
+
9598
+ public:
9599
+ explicit ForVarDeclFinder(llvm::SmallPtrSetImpl<const Decl *> &VD)
9600
+ : VarDecls(VD) {}
9601
+
9602
+ bool VisitForStmt(ForStmt *F) {
9603
+ ++NestingDepth;
9604
+ TraverseStmt(F->getBody());
9605
+ --NestingDepth;
9606
+ return false;
9607
+ }
9608
+
9609
+ bool VisitCXXForRangeStmt(CXXForRangeStmt *RF) {
9610
+ ++NestingDepth;
9611
+ TraverseStmt(RF->getBody());
9612
+ --NestingDepth;
9613
+ return false;
9614
+ }
9615
+
9616
+ bool VisitVarDecl(VarDecl *D) {
9617
+ Decl *C = D->getCanonicalDecl();
9618
+ if (NestingDepth > 0)
9619
+ VarDecls.insert(C);
9620
+ return true;
9621
+ }
9622
+ };
9623
+
9505
9624
/// Called on a for stmt to check itself and nested loops (if any).
9506
9625
/// \return Returns 0 if one of the collapsed stmts is not canonical for loop,
9507
9626
/// number of collapsed loops otherwise.
@@ -9514,13 +9633,17 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
9514
9633
unsigned NestedLoopCount = 1;
9515
9634
bool SupportsNonPerfectlyNested = (SemaRef.LangOpts.OpenMP >= 50) &&
9516
9635
!isOpenMPLoopTransformationDirective(DKind);
9636
+ llvm::SmallPtrSet<const Decl *, 4> CollapsedLoopVarDecls;
9517
9637
9518
9638
if (CollapseLoopCountExpr) {
9519
9639
// Found 'collapse' clause - calculate collapse number.
9520
9640
Expr::EvalResult Result;
9521
9641
if (!CollapseLoopCountExpr->isValueDependent() &&
9522
9642
CollapseLoopCountExpr->EvaluateAsInt(Result, SemaRef.getASTContext())) {
9523
9643
NestedLoopCount = Result.Val.getInt().getLimitedValue();
9644
+
9645
+ ForVarDeclFinder FVDF{CollapsedLoopVarDecls};
9646
+ FVDF.TraverseStmt(AStmt);
9524
9647
} else {
9525
9648
Built.clear(/*Size=*/1);
9526
9649
return 1;
@@ -9558,11 +9681,13 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr,
9558
9681
SupportsNonPerfectlyNested, NumLoops,
9559
9682
[DKind, &SemaRef, &DSA, NumLoops, NestedLoopCount,
9560
9683
CollapseLoopCountExpr, OrderedLoopCountExpr, &VarsWithImplicitDSA,
9561
- &IterSpaces, &Captures](unsigned Cnt, Stmt *CurStmt) {
9684
+ &IterSpaces, &Captures,
9685
+ &CollapsedLoopVarDecls](unsigned Cnt, Stmt *CurStmt) {
9562
9686
if (checkOpenMPIterationSpace(
9563
9687
DKind, CurStmt, SemaRef, DSA, Cnt, NestedLoopCount,
9564
9688
NumLoops, CollapseLoopCountExpr, OrderedLoopCountExpr,
9565
- VarsWithImplicitDSA, IterSpaces, Captures))
9689
+ VarsWithImplicitDSA, IterSpaces, Captures,
9690
+ CollapsedLoopVarDecls))
9566
9691
return true;
9567
9692
if (Cnt > 0 && Cnt >= NestedLoopCount &&
9568
9693
IterSpaces[Cnt].CounterVar) {
0 commit comments