[Clang][OpenMP] Add reverse directive #92916

Merged: 18 commits, Jul 18, 2024

Meinersbur

Add the reverse directive which will be introduced in the upcoming OpenMP 6.0 specification. A preview has been published in Technical Report 12.

This is the reverse directive part extracted out of #92030 which included reverse and interchange.
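
As a rough illustration of what the directive means for user code: applied to a canonical loop such as `for (int i = 0; i < n; ++i)`, `reverse` runs the same iterations in the opposite order. The sketch below emulates that order in plain C++ so that it compiles without OpenMP support. `reversed_for` and `visit_order` are hypothetical helpers made up for this illustration; they are not part of Clang or the OpenMP runtime.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: calls body(i) for i = n-1, ..., 0. This is the
// iteration order one would expect from
//   #pragma omp reverse
//   for (int i = 0; i < n; ++i) body(i);
template <typename Body>
void reversed_for(int n, Body body) {
  assert(n >= 0 && "trip count must not be negative");
  for (int k = 0; k < n; ++k)
    body(n - 1 - k);
}

// Records the order in which the loop counter values are visited.
std::vector<int> visit_order(int n) {
  std::vector<int> order;
  reversed_for(n, [&order](int i) { order.push_back(i); });
  return order;
}
```

With a compiler that implements the directive, the pragma form in the comment should visit the counter values in the same order as `visit_order`.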

@Meinersbur Meinersbur requested a review from alexey-bataev May 21, 2024 13:45
@Meinersbur Meinersbur marked this pull request as ready for review May 21, 2024 13:46
@llvmbot llvmbot added the labels clang, clang:frontend, clang:modules, clang:codegen, flang:openmp, clang:openmp, and openmp:libomp on May 21, 2024
llvmbot commented May 21, 2024

@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-clang-modules

Author: Michael Kruse (Meinersbur)

Patch is 144.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92916.diff

30 Files Affected:

  • (modified) clang/include/clang-c/Index.h (+4)
  • (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+3)
  • (modified) clang/include/clang/AST/StmtOpenMP.h (+70-2)
  • (modified) clang/include/clang/Basic/StmtNodes.td (+1)
  • (modified) clang/include/clang/Sema/SemaOpenMP.h (+5)
  • (modified) clang/include/clang/Serialization/ASTBitCodes.h (+1)
  • (modified) clang/lib/AST/StmtOpenMP.cpp (+19)
  • (modified) clang/lib/AST/StmtPrinter.cpp (+5)
  • (modified) clang/lib/AST/StmtProfile.cpp (+4)
  • (modified) clang/lib/Basic/OpenMPKinds.cpp (+2-1)
  • (modified) clang/lib/CodeGen/CGStmt.cpp (+3)
  • (modified) clang/lib/CodeGen/CGStmtOpenMP.cpp (+8)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+1)
  • (modified) clang/lib/Parse/ParseOpenMP.cpp (+2)
  • (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1)
  • (modified) clang/lib/Sema/SemaOpenMP.cpp (+190)
  • (modified) clang/lib/Sema/TreeTransform.h (+11)
  • (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+12)
  • (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+5)
  • (added) clang/test/OpenMP/reverse_ast_print.cpp (+159)
  • (added) clang/test/OpenMP/reverse_codegen.cpp (+1554)
  • (added) clang/test/OpenMP/reverse_messages.cpp (+40)
  • (modified) clang/tools/libclang/CIndex.cpp (+7)
  • (modified) clang/tools/libclang/CXCursor.cpp (+3)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+3)
  • (added) openmp/runtime/test/transform/reverse/foreach.cpp (+162)
  • (added) openmp/runtime/test/transform/reverse/intfor.c (+25)
  • (added) openmp/runtime/test/transform/reverse/iterfor.cpp (+164)
  • (added) openmp/runtime/test/transform/reverse/parallel-wsloop-collapse-foreach.cpp (+285)
  • (added) openmp/runtime/test/transform/reverse/parallel-wsloop-collapse-intfor.cpp (+51)
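
Several of the touched files (StmtOpenMP.h, CGStmtOpenMP.cpp, SemaOpenMP.cpp) extend `classof` checks and `dyn_cast` chains so that the new node is recognized as a loop transformation. The following is a minimal sketch of that LLVM-style RTTI pattern, assuming a toy hierarchy: the types and the `isa_sketch` / `dyn_cast_sketch` helpers are stand-ins invented for this example and only imitate the shape of the real Clang classes and `llvm::isa` / `llvm::dyn_cast`.

```cpp
#include <cassert>

// Toy base class carrying a kind tag, in the spirit of clang::Stmt.
struct Stmt {
  enum StmtClass { TileClass, UnrollClass, ReverseClass, ForClass };
  explicit Stmt(StmtClass K) : Kind(K) {}
  StmtClass getStmtClass() const { return Kind; }

private:
  StmtClass Kind;
};

// Intermediate class: classof must list every concrete subclass, which is
// why adding a new directive means touching this check.
struct LoopTransformation : Stmt {
  using Stmt::Stmt;
  static bool classof(const Stmt *T) {
    Stmt::StmtClass C = T->getStmtClass();
    return C == TileClass || C == UnrollClass || C == ReverseClass;
  }
};

struct Reverse : LoopTransformation {
  Reverse() : LoopTransformation(ReverseClass) {}
  static bool classof(const Stmt *T) {
    return T->getStmtClass() == ReverseClass;
  }
};

struct For : Stmt {
  For() : Stmt(ForClass) {}
};

// Hypothetical equivalents of llvm::isa and llvm::dyn_cast.
template <typename To> bool isa_sketch(const Stmt *S) {
  return To::classof(S);
}

template <typename To> const To *dyn_cast_sketch(const Stmt *S) {
  return To::classof(S) ? static_cast<const To *>(S) : nullptr;
}
```

If `ReverseClass` were left out of `LoopTransformation::classof`, a `Reverse` node would fail the intermediate check even though it derives from that class, which is presumably why the patch updates the base-class predicate together with the new node.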
diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
index 365b607c74117..c7d63818ece23 100644
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -2146,6 +2146,10 @@ enum CXCursorKind {
    */
   CXCursor_OMPScopeDirective = 306,
 
+  /** OpenMP reverse directive.
+   */
+  CXCursor_OMPReverseDirective = 307,
+
   /** OpenACC Compute Construct.
    */
   CXCursor_OpenACCComputeConstruct = 320,
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index f5cefedb07e0e..06b29d59785f6 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3021,6 +3021,9 @@ DEF_TRAVERSE_STMT(OMPTileDirective,
 DEF_TRAVERSE_STMT(OMPUnrollDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
+DEF_TRAVERSE_STMT(OMPReverseDirective,
+                  { TRY_TO(TraverseOMPExecutableDirective(S)); })
+
 DEF_TRAVERSE_STMT(OMPForDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index f735fa5643aec..4be2e2d3a4605 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -1007,8 +1007,9 @@ class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
   Stmt *getPreInits() const;
 
   static bool classof(const Stmt *T) {
-    return T->getStmtClass() == OMPTileDirectiveClass ||
-           T->getStmtClass() == OMPUnrollDirectiveClass;
+    Stmt::StmtClass C = T->getStmtClass();
+    return C == OMPTileDirectiveClass || C == OMPUnrollDirectiveClass ||
+           C == OMPReverseDirectiveClass;
   }
 };
 
@@ -5711,6 +5712,73 @@ class OMPUnrollDirective final : public OMPLoopTransformationDirective {
   }
 };
 
+/// Represents the '#pragma omp reverse' loop transformation directive.
+///
+/// \code
+/// #pragma omp reverse
+/// for (int i = 0; i < n; ++i)
+///   ...
+/// \endcode
+class OMPReverseDirective final : public OMPLoopTransformationDirective {
+  friend class ASTStmtReader;
+  friend class OMPExecutableDirective;
+
+  /// Offsets of child members.
+  enum {
+    PreInitsOffset = 0,
+    TransformedStmtOffset,
+  };
+
+  explicit OMPReverseDirective(SourceLocation StartLoc, SourceLocation EndLoc)
+      : OMPLoopTransformationDirective(OMPReverseDirectiveClass,
+                                       llvm::omp::OMPD_reverse, StartLoc,
+                                       EndLoc, 1) {}
+
+  void setPreInits(Stmt *PreInits) {
+    Data->getChildren()[PreInitsOffset] = PreInits;
+  }
+
+  void setTransformedStmt(Stmt *S) {
+    Data->getChildren()[TransformedStmtOffset] = S;
+  }
+
+public:
+  /// Create a new AST node representation for '#pragma omp reverse'.
+  ///
+  /// \param C         Context of the AST.
+  /// \param StartLoc  Location of the introducer (e.g. the 'omp' token).
+  /// \param EndLoc    Location of the directive's end (e.g. the tok::eod).
+  /// \param Clauses   The directive's clauses.
+  /// \param AssociatedStmt  The outermost associated loop.
+  /// \param TransformedStmt The loop nest after reversal, or nullptr in
+  ///                        dependent contexts.
+  /// \param PreInits   Helper preinits statements for the loop nest.
+  static OMPReverseDirective *
+  Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+         ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
+         Stmt *TransformedStmt, Stmt *PreInits);
+
+  /// Build an empty '#pragma omp reverse' AST node for deserialization.
+  ///
+  /// \param C          Context of the AST.
+  /// \param NumClauses Number of clauses to allocate.
+  static OMPReverseDirective *CreateEmpty(const ASTContext &C,
+                                          unsigned NumClauses);
+
+  /// Gets/sets the associated loops after the transformation, i.e. after
+  /// de-sugaring.
+  Stmt *getTransformedStmt() const {
+    return Data->getChildren()[TransformedStmtOffset];
+  }
+
+  /// Return preinits statement.
+  Stmt *getPreInits() const { return Data->getChildren()[PreInitsOffset]; }
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == OMPReverseDirectiveClass;
+  }
+};
+
 /// This represents '#pragma omp scan' directive.
 ///
 /// \code
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index 305f19daa4a92..b2e2be5c998bb 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -229,6 +229,7 @@ def OMPSimdDirective : StmtNode<OMPLoopDirective>;
 def OMPLoopTransformationDirective : StmtNode<OMPLoopBasedDirective, 1>;
 def OMPTileDirective : StmtNode<OMPLoopTransformationDirective>;
 def OMPUnrollDirective : StmtNode<OMPLoopTransformationDirective>;
+def OMPReverseDirective : StmtNode<OMPLoopTransformationDirective>;
 def OMPForDirective : StmtNode<OMPLoopDirective>;
 def OMPForSimdDirective : StmtNode<OMPLoopDirective>;
 def OMPSectionsDirective : StmtNode<OMPExecutableDirective>;
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 51981e1c9a8b9..e36a90ba4e1b9 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -422,6 +422,11 @@ class SemaOpenMP : public SemaBase {
   StmtResult ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
                                         Stmt *AStmt, SourceLocation StartLoc,
                                         SourceLocation EndLoc);
+  /// Called on well-formed '#pragma omp reverse' after parsing of its clauses
+  /// and the associated statement.
+  StmtResult ActOnOpenMPReverseDirective(ArrayRef<OMPClause *> Clauses,
+                                         Stmt *AStmt, SourceLocation StartLoc,
+                                         SourceLocation EndLoc);
   /// Called on well-formed '\#pragma omp for' after parsing
   /// of the associated statement.
   StmtResult
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index fe1bd47348be1..dee0d073557cc 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1856,6 +1856,7 @@ enum StmtCode {
   STMT_OMP_SIMD_DIRECTIVE,
   STMT_OMP_TILE_DIRECTIVE,
   STMT_OMP_UNROLL_DIRECTIVE,
+  STMT_OMP_REVERSE_DIRECTIVE,
   STMT_OMP_FOR_DIRECTIVE,
   STMT_OMP_FOR_SIMD_DIRECTIVE,
   STMT_OMP_SECTIONS_DIRECTIVE,
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index d8519b2071e6d..0be0d9d2cfa94 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -449,6 +449,25 @@ OMPUnrollDirective *OMPUnrollDirective::CreateEmpty(const ASTContext &C,
       SourceLocation(), SourceLocation());
 }
 
+OMPReverseDirective *
+OMPReverseDirective::Create(const ASTContext &C, SourceLocation StartLoc,
+                            SourceLocation EndLoc,
+                            ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
+                            Stmt *TransformedStmt, Stmt *PreInits) {
+  OMPReverseDirective *Dir = createDirective<OMPReverseDirective>(
+      C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc);
+  Dir->setTransformedStmt(TransformedStmt);
+  Dir->setPreInits(PreInits);
+  return Dir;
+}
+
+OMPReverseDirective *OMPReverseDirective::CreateEmpty(const ASTContext &C,
+                                                      unsigned NumClauses) {
+  return createEmptyDirective<OMPReverseDirective>(
+      C, NumClauses, /*HasAssociatedStmt=*/true, TransformedStmtOffset + 1,
+      SourceLocation(), SourceLocation());
+}
+
 OMPForSimdDirective *
 OMPForSimdDirective::Create(const ASTContext &C, SourceLocation StartLoc,
                             SourceLocation EndLoc, unsigned CollapsedNum,
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index be2d5a2eb6b46..64b481f680311 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -763,6 +763,11 @@ void StmtPrinter::VisitOMPUnrollDirective(OMPUnrollDirective *Node) {
   PrintOMPExecutableDirective(Node);
 }
 
+void StmtPrinter::VisitOMPReverseDirective(OMPReverseDirective *Node) {
+  Indent() << "#pragma omp reverse";
+  PrintOMPExecutableDirective(Node);
+}
+
 void StmtPrinter::VisitOMPForDirective(OMPForDirective *Node) {
   Indent() << "#pragma omp for";
   PrintOMPExecutableDirective(Node);
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index caab4ab0ef160..f0e1c9548de72 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -985,6 +985,10 @@ void StmtProfiler::VisitOMPUnrollDirective(const OMPUnrollDirective *S) {
   VisitOMPLoopTransformationDirective(S);
 }
 
+void StmtProfiler::VisitOMPReverseDirective(const OMPReverseDirective *S) {
+  VisitOMPLoopTransformationDirective(S);
+}
+
 void StmtProfiler::VisitOMPForDirective(const OMPForDirective *S) {
   VisitOMPLoopDirective(S);
 }
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index b3e9affbb3e58..803808c38e2fe 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -684,7 +684,7 @@ bool clang::isOpenMPLoopBoundSharingDirective(OpenMPDirectiveKind Kind) {
 }
 
 bool clang::isOpenMPLoopTransformationDirective(OpenMPDirectiveKind DKind) {
-  return DKind == OMPD_tile || DKind == OMPD_unroll;
+  return DKind == OMPD_tile || DKind == OMPD_unroll || DKind == OMPD_reverse;
 }
 
 bool clang::isOpenMPCombinedParallelADirective(OpenMPDirectiveKind DKind) {
@@ -808,6 +808,7 @@ void clang::getOpenMPCaptureRegions(
     break;
   case OMPD_tile:
   case OMPD_unroll:
+  case OMPD_reverse:
     // loop transformations do not introduce captures.
     break;
   case OMPD_threadprivate:
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 99daaa14cf3fe..93c2f8900dd12 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -222,6 +222,9 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {
   case Stmt::OMPUnrollDirectiveClass:
     EmitOMPUnrollDirective(cast<OMPUnrollDirective>(*S));
     break;
+  case Stmt::OMPReverseDirectiveClass:
+    EmitOMPReverseDirective(cast<OMPReverseDirective>(*S));
+    break;
   case Stmt::OMPForDirectiveClass:
     EmitOMPForDirective(cast<OMPForDirective>(*S));
     break;
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index 6410f9e102c90..ad6c044aa483b 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -187,6 +187,8 @@ class OMPLoopScope : public CodeGenFunction::RunCleanupsScope {
       PreInits = Tile->getPreInits();
     } else if (const auto *Unroll = dyn_cast<OMPUnrollDirective>(&S)) {
       PreInits = Unroll->getPreInits();
+    } else if (const auto *Reverse = dyn_cast<OMPReverseDirective>(&S)) {
+      PreInits = Reverse->getPreInits();
     } else {
       llvm_unreachable("Unknown loop-based directive kind.");
     }
@@ -2762,6 +2764,12 @@ void CodeGenFunction::EmitOMPTileDirective(const OMPTileDirective &S) {
   EmitStmt(S.getTransformedStmt());
 }
 
+void CodeGenFunction::EmitOMPReverseDirective(const OMPReverseDirective &S) {
+  // Emit the de-sugared statement.
+  OMPTransformDirectiveScopeRAII ReverseScope(*this, &S);
+  EmitStmt(S.getTransformedStmt());
+}
+
 void CodeGenFunction::EmitOMPUnrollDirective(const OMPUnrollDirective &S) {
   bool UseOMPIRBuilder = CGM.getLangOpts().OpenMPIRBuilder;
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 5f3ee7eb943f9..ac738e1e82886 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3807,6 +3807,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   void EmitOMPSimdDirective(const OMPSimdDirective &S);
   void EmitOMPTileDirective(const OMPTileDirective &S);
   void EmitOMPUnrollDirective(const OMPUnrollDirective &S);
+  void EmitOMPReverseDirective(const OMPReverseDirective &S);
   void EmitOMPForDirective(const OMPForDirective &S);
   void EmitOMPForSimdDirective(const OMPForSimdDirective &S);
   void EmitOMPSectionsDirective(const OMPSectionsDirective &S);
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index e959dd6378f46..57fcf6ce520ac 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -2384,6 +2384,7 @@ Parser::DeclGroupPtrTy Parser::ParseOpenMPDeclarativeDirectiveWithExtDecl(
   case OMPD_simd:
   case OMPD_tile:
   case OMPD_unroll:
+  case OMPD_reverse:
   case OMPD_task:
   case OMPD_taskyield:
   case OMPD_barrier:
@@ -2802,6 +2803,7 @@ StmtResult Parser::ParseOpenMPDeclarativeOrExecutableDirective(
   case OMPD_simd:
   case OMPD_tile:
   case OMPD_unroll:
+  case OMPD_reverse:
   case OMPD_for:
   case OMPD_for_simd:
   case OMPD_sections:
diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp
index 41bf273d12f2f..4de7183cde281 100644
--- a/clang/lib/Sema/SemaExceptionSpec.cpp
+++ b/clang/lib/Sema/SemaExceptionSpec.cpp
@@ -1486,6 +1486,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) {
   case Stmt::OMPSimdDirectiveClass:
   case Stmt::OMPTileDirectiveClass:
   case Stmt::OMPUnrollDirectiveClass:
+  case Stmt::OMPReverseDirectiveClass:
   case Stmt::OMPSingleDirectiveClass:
   case Stmt::OMPTargetDataDirectiveClass:
   case Stmt::OMPTargetDirectiveClass:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 663dbb927250e..7b9898704eb1c 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -4334,6 +4334,7 @@ void SemaOpenMP::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind,
   case OMPD_masked:
   case OMPD_tile:
   case OMPD_unroll:
+  case OMPD_reverse:
     break;
   case OMPD_loop:
     // TODO: 'loop' may require additional parameters depending on the binding.
@@ -6546,6 +6547,10 @@ StmtResult SemaOpenMP::ActOnOpenMPExecutableDirective(
     Res = ActOnOpenMPUnrollDirective(ClausesWithImplicit, AStmt, StartLoc,
                                      EndLoc);
     break;
+  case OMPD_reverse:
+    Res = ActOnOpenMPReverseDirective(ClausesWithImplicit, AStmt, StartLoc,
+                                      EndLoc);
+    break;
   case OMPD_for:
     Res = ActOnOpenMPForDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc,
                                   VarsWithInheritedDSA);
@@ -15121,6 +15126,8 @@ bool SemaOpenMP::checkTransformableLoopNest(
           DependentPreInits = Dir->getPreInits();
         else if (auto *Dir = dyn_cast<OMPUnrollDirective>(Transform))
           DependentPreInits = Dir->getPreInits();
+        else if (auto *Dir = dyn_cast<OMPReverseDirective>(Transform))
+          DependentPreInits = Dir->getPreInits();
         else
           llvm_unreachable("Unhandled loop transformation");
         if (!DependentPreInits)
@@ -15746,6 +15753,189 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
                                     buildPreInits(Context, PreInits));
 }
 
+StmtResult
+SemaOpenMP::ActOnOpenMPReverseDirective(ArrayRef<OMPClause *> Clauses,
+                                        Stmt *AStmt, SourceLocation StartLoc,
+                                        SourceLocation EndLoc) {
+  ASTContext &Context = getASTContext();
+  Scope *CurScope = SemaRef.getCurScope();
+  assert(Clauses.empty() && "reverse directive does not accept any clauses; "
+                            "must have been checked before");
+
+  // Empty statement should only be possible if there already was an error.
+  if (!AStmt)
+    return StmtError();
+
+  constexpr unsigned NumLoops = 1;
+  Stmt *Body = nullptr;
+  SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers(
+      NumLoops);
+  SmallVector<SmallVector<Stmt *, 0>, NumLoops + 1> OriginalInits;
+  if (!checkTransformableLoopNest(OMPD_reverse, AStmt, NumLoops, LoopHelpers,
+                                  Body, OriginalInits))
+    return StmtError();
+
+  // Delay applying the transformation to when template is completely
+  // instantiated.
+  if (SemaRef.CurContext->isDependentContext())
+    return OMPReverseDirective::Create(Context, StartLoc, EndLoc, Clauses,
+                                       AStmt, nullptr, nullptr);
+
+  assert(LoopHelpers.size() == NumLoops &&
+         "Expecting a single-dimensional loop iteration space");
+  assert(OriginalInits.size() == NumLoops &&
+         "Expecting a single-dimensional loop iteration space");
+  OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();
+
+  // Find the loop statement.
+  Stmt *LoopStmt = nullptr;
+  collectLoopStmts(AStmt, {LoopStmt});
+
+  // Determine the PreInit declarations.
+  SmallVector<Stmt *> PreInits;
+  addLoopPreInits(Context, LoopHelper, LoopStmt, OriginalInits[0], PreInits);
+
+  auto *IterationVarRef = cast<DeclRefExpr>(LoopHelper.IterationVarRef);
+  QualType IVTy = IterationVarRef->getType();
+  uint64_t IVWidth = Context.getTypeSize(IVTy);
+  auto *OrigVar = cast<DeclRefExpr>(LoopHelper.Counters.front());
+
+  // Iteration variable SourceLocations.
+  SourceLocation OrigVarLoc = OrigVar->getExprLoc();
+  SourceLocation OrigVarLocBegin = OrigVar->getBeginLoc();
+  SourceLocation OrigVarLocEnd = OrigVar->getEndLoc();
+
+  // Locations pointing to the transformation.
+  SourceLocation TransformLoc = StartLoc;
+  SourceLocation TransformLocBegin = StartLoc;
+  SourceLocation TransformLocEnd = EndLoc;
+
+  // Internal variable names.
+  std::string OrigVarName = OrigVar->getNameInfo().getAsString();
+  std::string TripCountName = (Twine(".tripcount.") + OrigVarName).str();
+  std::string ForwardIVName = (Twine(".forward.iv.") + OrigVarName).str();
+  std::string ReversedIVName = (Twine(".reversed.iv.") + OrigVarName).str();
+
+  // LoopHelper.Updates will read the logical iteration number from
+  // LoopHelper.IterationVarRef, compute the value of the user loop counter of
+  // that logical iteration from it, then assign it to the user loop counter
+  // variable. We cannot directly use LoopHelper.IterationVarRef as the
+  // induction variable of the generated loop because it may cause an underflow:
+  // \code
+  //   for (unsigned i = 0; i < n; ++i)
+  //     body(i);
+  // \endcode
+  //
+  // Naive reversal:
+  // \code
+  //   for (unsigned i = n-1; i >= 0; --i)
+  //     body(i);
+  // \endcode
+  //
+  // Instead, we introduce a new iteration variable representing the logical
+  // iteration counter of the original loop, convert it to the logical iteration
+  // number of the reversed loop, then let LoopHelper.Updates compute the user's
+  // loop iteration variable from it.
+  // \code
+  //   for (auto .forward.iv = 0; .forward.iv < n; ++.forward.iv) {
+  //     auto .reversed.iv = n - .forward.iv - 1;
+  //     i = (.reversed.iv + 0) * 1;                // LoopHelper.Updates
+  //     body(i);                                   // Body
+  //   }
+  // \endcode
+
+  // Subexpressions with more than one use. One of the constraints of an AST is
+  // that every node object must appear at most once, hence we define a lambda
+  // that creates a new AST node at every use.
+  CaptureVars CopyTransformer(SemaRef);
+  auto MakeNumIterations = [&CopyTransformer, &LoopHelper]() -> Expr * {
+    return AssertSuccess(
+        CopyTransformer.TransformExpr(LoopHelper.NumIterations));
+  };
+
+  // Create the iteration variable for the forward loop (from 0 to n-1).
+  VarDecl *ForwardIVDecl =
+      buildVarDecl(SemaRef, {}, IVTy, ForwardIVName, nullptr, OrigVar);
+  auto MakeForwardRef = [&SemaRef = this->SemaRef, ForwardIVDecl, IVTy,
+                         OrigVarLoc]() {
+    return buildDeclRefExpr(SemaRef, ForwardIVDecl, IVTy, OrigVarLoc);
+  };
+
+  // Iteration variable for the reversed induction variable (from n-1 downto 0):
+  // Reuse the iteration variable created by checkOpenMPLoop.
+  auto *ReversedIVDecl = cast<VarDecl>(IterationVarRef->getDecl());
+  ReversedIVDecl->setDeclName(
+      &SemaRef.PP.getIdentifierTable().get(ReversedIVName));
+
+  // For init-statement:
+  // \code
+  //   auto .forward.iv = 0
+  // \endcode
+  IntegerLiteral *Zero =
+      IntegerLiteral::Create(Context, llvm::APInt::getZero(IVWi...
[truncated]
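
The comment in the SemaOpenMP.cpp hunk explains why the generated loop counts forward and derives a reversed logical iteration number, instead of counting an unsigned variable down to zero (where `i >= 0` is always true and `--i` wraps around). A minimal sketch of that scheme in ordinary C++ follows. It assumes an original loop of the form `for (unsigned i = lb; i < ub; i += step)`; the function name `reversed_iterations`, the parameters, and the trip-count formula are choices made for this example and are not taken from the patch.

```cpp
#include <cassert>
#include <vector>

// Returns the values of the user loop counter `i` in the order a reversed
// loop would visit them, using only a forward-counting induction variable.
std::vector<unsigned> reversed_iterations(unsigned lb, unsigned ub,
                                          unsigned step) {
  assert(step > 0 && "a canonical loop needs a non-zero step");

  // Number of iterations of the original loop (0 if the range is empty).
  const unsigned tripcount = ub > lb ? (ub - lb - 1) / step + 1 : 0;

  std::vector<unsigned> visited;
  // Plays the role of .forward.iv: runs from 0 to tripcount-1.
  for (unsigned forward_iv = 0; forward_iv < tripcount; ++forward_iv) {
    // Plays the role of .reversed.iv: logical iteration of the original
    // loop. forward_iv < tripcount, so this subtraction cannot wrap.
    const unsigned reversed_iv = tripcount - forward_iv - 1;
    // Counterpart of LoopHelper.Updates: recompute the user's counter.
    const unsigned i = lb + reversed_iv * step;
    visited.push_back(i);
  }
  return visited;
}
```

Because the loop condition only compares `forward_iv` against the trip count, an empty range simply yields zero iterations and no unsigned value is ever decremented below zero.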

llvmbot commented May 21, 2024

@llvm/pr-subscribers-clang

+  return createEmptyDirective<OMPReverseDirective>(
+      C, NumClauses, /*HasAssociatedStmt=*/true, TransformedStmtOffset + 1,
+      SourceLocation(), SourceLocation());
+}
+
 OMPForSimdDirective *
 OMPForSimdDirective::Create(const ASTContext &C, SourceLocation StartLoc,
                             SourceLocation EndLoc, unsigned CollapsedNum,
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index be2d5a2eb6b46..64b481f680311 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -763,6 +763,11 @@ void StmtPrinter::VisitOMPUnrollDirective(OMPUnrollDirective *Node) {
   PrintOMPExecutableDirective(Node);
 }
 
+void StmtPrinter::VisitOMPReverseDirective(OMPReverseDirective *Node) {
+  Indent() << "#pragma omp reverse";
+  PrintOMPExecutableDirective(Node);
+}
+
 void StmtPrinter::VisitOMPForDirective(OMPForDirective *Node) {
   Indent() << "#pragma omp for";
   PrintOMPExecutableDirective(Node);
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index caab4ab0ef160..f0e1c9548de72 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -985,6 +985,10 @@ void StmtProfiler::VisitOMPUnrollDirective(const OMPUnrollDirective *S) {
   VisitOMPLoopTransformationDirective(S);
 }
 
+void StmtProfiler::VisitOMPReverseDirective(const OMPReverseDirective *S) {
+  VisitOMPLoopTransformationDirective(S);
+}
+
 void StmtProfiler::VisitOMPForDirective(const OMPForDirective *S) {
   VisitOMPLoopDirective(S);
 }
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index b3e9affbb3e58..803808c38e2fe 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -684,7 +684,7 @@ bool clang::isOpenMPLoopBoundSharingDirective(OpenMPDirectiveKind Kind) {
 }
 
 bool clang::isOpenMPLoopTransformationDirective(OpenMPDirectiveKind DKind) {
-  return DKind == OMPD_tile || DKind == OMPD_unroll;
+  return DKind == OMPD_tile || DKind == OMPD_unroll || DKind == OMPD_reverse;
 }
 
 bool clang::isOpenMPCombinedParallelADirective(OpenMPDirectiveKind DKind) {
@@ -808,6 +808,7 @@ void clang::getOpenMPCaptureRegions(
     break;
   case OMPD_tile:
   case OMPD_unroll:
+  case OMPD_reverse:
     // loop transformations do not introduce captures.
     break;
   case OMPD_threadprivate:
diff --git a/clang/lib/CodeGen/CGStmt.cpp b/clang/lib/CodeGen/CGStmt.cpp
index 99daaa14cf3fe..93c2f8900dd12 100644
--- a/clang/lib/CodeGen/CGStmt.cpp
+++ b/clang/lib/CodeGen/CGStmt.cpp
@@ -222,6 +222,9 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {
   case Stmt::OMPUnrollDirectiveClass:
     EmitOMPUnrollDirective(cast<OMPUnrollDirective>(*S));
     break;
+  case Stmt::OMPReverseDirectiveClass:
+    EmitOMPReverseDirective(cast<OMPReverseDirective>(*S));
+    break;
   case Stmt::OMPForDirectiveClass:
     EmitOMPForDirective(cast<OMPForDirective>(*S));
     break;
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index 6410f9e102c90..ad6c044aa483b 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -187,6 +187,8 @@ class OMPLoopScope : public CodeGenFunction::RunCleanupsScope {
       PreInits = Tile->getPreInits();
     } else if (const auto *Unroll = dyn_cast<OMPUnrollDirective>(&S)) {
       PreInits = Unroll->getPreInits();
+    } else if (const auto *Reverse = dyn_cast<OMPReverseDirective>(&S)) {
+      PreInits = Reverse->getPreInits();
     } else {
       llvm_unreachable("Unknown loop-based directive kind.");
     }
@@ -2762,6 +2764,12 @@ void CodeGenFunction::EmitOMPTileDirective(const OMPTileDirective &S) {
   EmitStmt(S.getTransformedStmt());
 }
 
+void CodeGenFunction::EmitOMPReverseDirective(const OMPReverseDirective &S) {
+  // Emit the de-sugared statement.
+  OMPTransformDirectiveScopeRAII ReverseScope(*this, &S);
+  EmitStmt(S.getTransformedStmt());
+}
+
 void CodeGenFunction::EmitOMPUnrollDirective(const OMPUnrollDirective &S) {
   bool UseOMPIRBuilder = CGM.getLangOpts().OpenMPIRBuilder;
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 5f3ee7eb943f9..ac738e1e82886 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3807,6 +3807,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   void EmitOMPSimdDirective(const OMPSimdDirective &S);
   void EmitOMPTileDirective(const OMPTileDirective &S);
   void EmitOMPUnrollDirective(const OMPUnrollDirective &S);
+  void EmitOMPReverseDirective(const OMPReverseDirective &S);
   void EmitOMPForDirective(const OMPForDirective &S);
   void EmitOMPForSimdDirective(const OMPForSimdDirective &S);
   void EmitOMPSectionsDirective(const OMPSectionsDirective &S);
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index e959dd6378f46..57fcf6ce520ac 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -2384,6 +2384,7 @@ Parser::DeclGroupPtrTy Parser::ParseOpenMPDeclarativeDirectiveWithExtDecl(
   case OMPD_simd:
   case OMPD_tile:
   case OMPD_unroll:
+  case OMPD_reverse:
   case OMPD_task:
   case OMPD_taskyield:
   case OMPD_barrier:
@@ -2802,6 +2803,7 @@ StmtResult Parser::ParseOpenMPDeclarativeOrExecutableDirective(
   case OMPD_simd:
   case OMPD_tile:
   case OMPD_unroll:
+  case OMPD_reverse:
   case OMPD_for:
   case OMPD_for_simd:
   case OMPD_sections:
diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp
index 41bf273d12f2f..4de7183cde281 100644
--- a/clang/lib/Sema/SemaExceptionSpec.cpp
+++ b/clang/lib/Sema/SemaExceptionSpec.cpp
@@ -1486,6 +1486,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) {
   case Stmt::OMPSimdDirectiveClass:
   case Stmt::OMPTileDirectiveClass:
   case Stmt::OMPUnrollDirectiveClass:
+  case Stmt::OMPReverseDirectiveClass:
   case Stmt::OMPSingleDirectiveClass:
   case Stmt::OMPTargetDataDirectiveClass:
   case Stmt::OMPTargetDirectiveClass:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 663dbb927250e..7b9898704eb1c 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -4334,6 +4334,7 @@ void SemaOpenMP::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind,
   case OMPD_masked:
   case OMPD_tile:
   case OMPD_unroll:
+  case OMPD_reverse:
     break;
   case OMPD_loop:
     // TODO: 'loop' may require additional parameters depending on the binding.
@@ -6546,6 +6547,10 @@ StmtResult SemaOpenMP::ActOnOpenMPExecutableDirective(
     Res = ActOnOpenMPUnrollDirective(ClausesWithImplicit, AStmt, StartLoc,
                                      EndLoc);
     break;
+  case OMPD_reverse:
+    Res = ActOnOpenMPReverseDirective(ClausesWithImplicit, AStmt, StartLoc,
+                                      EndLoc);
+    break;
   case OMPD_for:
     Res = ActOnOpenMPForDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc,
                                   VarsWithInheritedDSA);
@@ -15121,6 +15126,8 @@ bool SemaOpenMP::checkTransformableLoopNest(
           DependentPreInits = Dir->getPreInits();
         else if (auto *Dir = dyn_cast<OMPUnrollDirective>(Transform))
           DependentPreInits = Dir->getPreInits();
+        else if (auto *Dir = dyn_cast<OMPReverseDirective>(Transform))
+          DependentPreInits = Dir->getPreInits();
         else
           llvm_unreachable("Unhandled loop transformation");
         if (!DependentPreInits)
@@ -15746,6 +15753,189 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
                                     buildPreInits(Context, PreInits));
 }
 
+StmtResult
+SemaOpenMP::ActOnOpenMPReverseDirective(ArrayRef<OMPClause *> Clauses,
+                                        Stmt *AStmt, SourceLocation StartLoc,
+                                        SourceLocation EndLoc) {
+  ASTContext &Context = getASTContext();
+  Scope *CurScope = SemaRef.getCurScope();
+  assert(Clauses.empty() && "reverse directive does not accept any clauses; "
+                            "must have been checked before");
+
+  // Empty statement should only be possible if there already was an error.
+  if (!AStmt)
+    return StmtError();
+
+  constexpr unsigned NumLoops = 1;
+  Stmt *Body = nullptr;
+  SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers(
+      NumLoops);
+  SmallVector<SmallVector<Stmt *, 0>, NumLoops + 1> OriginalInits;
+  if (!checkTransformableLoopNest(OMPD_reverse, AStmt, NumLoops, LoopHelpers,
+                                  Body, OriginalInits))
+    return StmtError();
+
+  // Delay applying the transformation to when template is completely
+  // instantiated.
+  if (SemaRef.CurContext->isDependentContext())
+    return OMPReverseDirective::Create(Context, StartLoc, EndLoc, Clauses,
+                                       AStmt, nullptr, nullptr);
+
+  assert(LoopHelpers.size() == NumLoops &&
+         "Expecting a single-dimensional loop iteration space");
+  assert(OriginalInits.size() == NumLoops &&
+         "Expecting a single-dimensional loop iteration space");
+  OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();
+
+  // Find the loop statement.
+  Stmt *LoopStmt = nullptr;
+  collectLoopStmts(AStmt, {LoopStmt});
+
+  // Determine the PreInit declarations.
+  SmallVector<Stmt *> PreInits;
+  addLoopPreInits(Context, LoopHelper, LoopStmt, OriginalInits[0], PreInits);
+
+  auto *IterationVarRef = cast<DeclRefExpr>(LoopHelper.IterationVarRef);
+  QualType IVTy = IterationVarRef->getType();
+  uint64_t IVWidth = Context.getTypeSize(IVTy);
+  auto *OrigVar = cast<DeclRefExpr>(LoopHelper.Counters.front());
+
+  // Iteration variable SourceLocations.
+  SourceLocation OrigVarLoc = OrigVar->getExprLoc();
+  SourceLocation OrigVarLocBegin = OrigVar->getBeginLoc();
+  SourceLocation OrigVarLocEnd = OrigVar->getEndLoc();
+
+  // Locations pointing to the transformation.
+  SourceLocation TransformLoc = StartLoc;
+  SourceLocation TransformLocBegin = StartLoc;
+  SourceLocation TransformLocEnd = EndLoc;
+
+  // Internal variable names.
+  std::string OrigVarName = OrigVar->getNameInfo().getAsString();
+  std::string TripCountName = (Twine(".tripcount.") + OrigVarName).str();
+  std::string ForwardIVName = (Twine(".forward.iv.") + OrigVarName).str();
+  std::string ReversedIVName = (Twine(".reversed.iv.") + OrigVarName).str();
+
+  // LoopHelper.Updates will read the logical iteration number from
+  // LoopHelper.IterationVarRef, compute the value of the user loop counter of
+  // that logical iteration from it, then assign it to the user loop counter
+  // variable. We cannot directly use LoopHelper.IterationVarRef as the
+  // induction variable of the generated loop because it may cause an underflow:
+  // \code
+  //   for (unsigned i = 0; i < n; ++i)
+  //     body(i);
+  // \endcode
+  //
+  // Naive reversal:
+  // \code
+  //   for (unsigned i = n-1; i >= 0; --i)
+  //     body(i);
+  // \endcode
+  //
+  // Instead, we introduce a new iteration variable representing the logical
+  // iteration counter of the original loop, convert it to the logical iteration
+  // number of the reversed loop, then let LoopHelper.Updates compute the user's
+  // loop iteration variable from it.
+  // \code
+  //   for (auto .forward.iv = 0; .forward.iv < n; ++.forward.iv) {
+  //     auto .reversed.iv = n - .forward.iv - 1;
+  //     i = (.reversed.iv + 0) * 1;                // LoopHelper.Updates
+  //     body(i);                                   // Body
+  //   }
+  // \endcode
+
+  // Subexpressions with more than one use. One of the constraints of an AST is
+  // that every node object must appear at most once, hence we define a lambda
+  // that creates a new AST node at every use.
+  CaptureVars CopyTransformer(SemaRef);
+  auto MakeNumIterations = [&CopyTransformer, &LoopHelper]() -> Expr * {
+    return AssertSuccess(
+        CopyTransformer.TransformExpr(LoopHelper.NumIterations));
+  };
+
+  // Create the iteration variable for the forward loop (from 0 to n-1).
+  VarDecl *ForwardIVDecl =
+      buildVarDecl(SemaRef, {}, IVTy, ForwardIVName, nullptr, OrigVar);
+  auto MakeForwardRef = [&SemaRef = this->SemaRef, ForwardIVDecl, IVTy,
+                         OrigVarLoc]() {
+    return buildDeclRefExpr(SemaRef, ForwardIVDecl, IVTy, OrigVarLoc);
+  };
+
+  // Iteration variable for the reversed induction variable (from n-1 downto 0):
+  // Reuse the iteration variable created by checkOpenMPLoop.
+  auto *ReversedIVDecl = cast<VarDecl>(IterationVarRef->getDecl());
+  ReversedIVDecl->setDeclName(
+      &SemaRef.PP.getIdentifierTable().get(ReversedIVName));
+
+  // For init-statement:
+  // \code
+  //   auto .forward.iv = 0
+  // \endcode
+  IntegerLiteral *Zero =
+      IntegerLiteral::Create(Context, llvm::APInt::getZero(IVWi...
[truncated]
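
The code comment in `ActOnOpenMPReverseDirective` above explains why the generated loop counts forward and derives the reversed iteration number inside the body, rather than counting an unsigned variable down to zero. As a minimal standalone sketch of that idea (plain C++, not the AST that Sema actually builds; `reversedLoop`, `collectReversed`, `forwardIV` and `reversedIV` are hypothetical names chosen for illustration, and a lower bound of 0 with a step of 1 is assumed):

```cpp
#include <vector>

// Runs the iterations of `for (unsigned i = 0; i < n; ++i) body(i);`
// in reverse order without ever decrementing an unsigned counter below 0.
template <typename Body>
void reversedLoop(unsigned n, Body body) {
  // The forward counter plays the role of `.forward.iv` in the comment.
  for (unsigned forwardIV = 0; forwardIV < n; ++forwardIV) {
    // Logical iteration number of the reversed loop (`.reversed.iv`).
    unsigned reversedIV = n - forwardIV - 1;
    // Recover the user's loop counter from the logical iteration number;
    // with the assumed lower bound 0 and step 1 this is the identity.
    unsigned i = 0 + reversedIV * 1;
    body(i);
  }
}

// Helper for inspection: records the order in which `i` is visited.
std::vector<unsigned> collectReversed(unsigned n) {
  std::vector<unsigned> order;
  reversedLoop(n, [&order](unsigned i) { order.push_back(i); });
  return order;
}
```

Because `forwardIV < n` holds inside the body, `n - forwardIV - 1` never wraps around, and a trip count of zero simply skips the loop.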

Base automatically changed from users/meinersbur/clang_openmp_unroll-tile_foreach to main May 22, 2024 12:30
@Meinersbur
Member Author

@alexey-bataev ping

@Meinersbur
Member Author

#92916 has been accepted, but waiting for this PR.

Comment on lines 15920 to 15925
SmallVector<Stmt *> BodyStmts;
BodyStmts.push_back(InitReversed.get());
llvm::append_range(BodyStmts, LoopHelper.Updates);
if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
BodyStmts.push_back(CXXRangeFor->getLoopVarStmt());
BodyStmts.push_back(Body);
Member

Suggested change
- SmallVector<Stmt *> BodyStmts;
- BodyStmts.push_back(InitReversed.get());
- llvm::append_range(BodyStmts, LoopHelper.Updates);
- if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
-   BodyStmts.push_back(CXXRangeFor->getLoopVarStmt());
- BodyStmts.push_back(Body);
+ SmallVector<Stmt *> BodyStmts(LoopHelper.Updates.size() + 2 + (isa<CXXForRangeStmt>(LoopStmt) ? 1 : 0));
+ BodyStmts.front() = InitReversed.get();
+ llvm::copy(LoopHelper.Updates, std::next(BodyStmts.begin()));
+ if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
+   BodyStmts[LoopHelper.Updates.size() + 1] = CXXRangeFor->getLoopVarStmt();
+ BodyStmts.back() = Body;

Member Author

I disagree with this change. It makes the code harder to read and makes it more difficult to keep the size computation and indices consistent when the code changes, all to avoid a possible reallocation that most likely will not happen anyway because the vector stays within the SmallVector's small size.

On my 64-bit system, the small size is 6. LoopHelper.Updates.size() will be 1 (reverse applies to one loop, i.e. just one loop counter). That is, the size will be either 3 or 4.

Member

It reduces memory fragmentation. It is better to preallocate the buffer if the size is known.

Member Author

The entire vector is on the stack; there is no memory fragmentation. reserve() would also be able to do a pre-allocation without manual iterator calculation.

Member

No, it is not on the stack. It is on the stack only in the preallocated (small) cases; otherwise it is dynamically allocated.
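
The disagreement in this thread comes down to when a small-size-optimized vector leaves its inline buffer. The following is a rough sketch of that behavior using a toy container; `ToySmallVector`, `usesHeap` and `bodyStmtCount` are hypothetical names and this is not `llvm::SmallVector`, whose real implementation differs. The inline capacity of 6 used in the usage note is the value quoted in the thread for one 64-bit system, not a guaranteed constant.

```cpp
#include <array>
#include <cstddef>
#include <vector>

// Toy stand-in for a small-size-optimized vector (illustration only).
template <typename T, std::size_t InlineCapacity>
class ToySmallVector {
  std::array<T, InlineCapacity> Inline{}; // storage embedded in the object
  std::vector<T> Heap;                    // used only after spilling
  std::size_t Size = 0;
  bool Spilled = false;

public:
  void push_back(const T &Value) {
    if (!Spilled && Size == InlineCapacity) {
      // Inline buffer is full: copy the elements to dynamic storage.
      Heap.assign(Inline.begin(), Inline.end());
      Spilled = true;
    }
    if (Spilled)
      Heap.push_back(Value);
    else
      Inline[Size] = Value;
    ++Size;
  }

  std::size_t size() const { return Size; }
  bool usesHeap() const { return Spilled; }
  const T &operator[](std::size_t I) const {
    return Spilled ? Heap[I] : Inline[I];
  }
};

// Statement count of the reversed loop body as described in the thread:
// the reversed-IV initialization, the counter updates, an optional
// range-for loop variable statement, and the original body.
inline std::size_t bodyStmtCount(std::size_t NumUpdates, bool IsRangeFor) {
  return 1 + NumUpdates + (IsRangeFor ? 1 : 0) + 1;
}
```

With one update statement, `bodyStmtCount` yields 3 or 4, so a `ToySmallVector<int, 6>` filled with that many elements reports `usesHeap() == false`; only a seventh `push_back` would move it to dynamic storage. Both statements in the thread are therefore compatible: the storage is inline for the sizes that occur here, and dynamically allocated only beyond the inline capacity.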

Member

@alexey-bataev alexey-bataev left a comment

LG with a nit

@Meinersbur Meinersbur merged commit 80865c0 into main Jul 18, 2024
8 of 9 checks passed
@Meinersbur Meinersbur deleted the users/meinersbur/clang_openmp_reverse branch July 18, 2024 08:35
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Summary:
Add the reverse directive which will be introduced in the upcoming
OpenMP 6.0 specification. A preview has been published in [Technical
Report 12](https://www.openmp.org/wp-content/uploads/openmp-TR12.pdf).

---------

Co-authored-by: Alexey Bataev <[email protected]>

Differential Revision: https://phabricator.intern.facebook.com/D60250833
Labels
clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang:openmp OpenMP related changes to Clang clang Clang issues not falling into any other category flang:openmp openmp:libomp OpenMP host runtime