[NFC][mlir][OpenMP] Remove mentions of target from generic loop rewrite #124528

Merged
merged 1 commit into from
Jan 27, 2025

Conversation

ergawy

@ergawy ergawy commented Jan 27, 2025

This removes mentions of target from the generic loop rewrite pass
since there is no need for it anyway. It is enough to detect the loop's
nesting within teams or parallel directives.
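
As a rough illustration of that nesting check, here is a minimal standalone C++ sketch. It does not use the MLIR API: `Op`, `classify`, and `example` are hypothetical stand-ins, and only the enum names mirror the ones in the patch. The point it tries to show is that the classification looks only at the immediate parent of the `loop` op, so an enclosing `target` makes no difference.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for an MLIR operation: just a name and a parent link.
struct Op {
  std::string name;
  const Op *parent = nullptr;
};

// Same enumerator names as in the patch.
enum class GenericLoopCombinedInfo { Standalone, TeamsLoop, ParallelLoop };

// Classify a `loop` op by inspecting only its immediate parent.
// Whether that parent is itself nested in `omp.target` is never checked.
GenericLoopCombinedInfo classify(const Op &loopOp) {
  const Op *parent = loopOp.parent;
  if (parent == nullptr)
    return GenericLoopCombinedInfo::Standalone;
  if (parent->name == "omp.teams")
    return GenericLoopCombinedInfo::TeamsLoop;
  if (parent->name == "omp.parallel")
    return GenericLoopCombinedInfo::ParallelLoop;
  return GenericLoopCombinedInfo::Standalone;
}

// Usage: `teams` gives the same result with or without an outer `target`.
void example() {
  Op target{"omp.target", nullptr};
  Op teamsInTarget{"omp.teams", &target};
  Op hostTeams{"omp.teams", nullptr};
  Op loopA{"omp.loop", &teamsInTarget};
  Op loopB{"omp.loop", &hostTeams};
  assert(classify(loopA) == GenericLoopCombinedInfo::TeamsLoop);
  assert(classify(loopB) == GenericLoopCombinedInfo::TeamsLoop);
}
```

In the actual pass the same decision is made with `mlir::dyn_cast_if_present` on the parent operation, as the diff shows; the string comparison here is only a simplification for the sketch.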

@llvmbot added the flang, flang:fir-hlfir, and flang:openmp labels Jan 27, 2025
@llvmbot

llvmbot commented Jan 27, 2025

@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kareem Ergawy (ergawy)

Changes

This removes mentions of target from the generic loop rewrite pass
since there is no need for it anyway. It is enough to detect the loop's
nesting within teams or parallel directives.


Full diff: https://github.com/llvm/llvm-project/pull/124528.diff

3 Files Affected:

  • (modified) flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp (+68-33)
  • (modified) flang/test/Lower/OpenMP/loop-directive.f90 (+41-1)
  • (modified) flang/test/Transforms/generic-loop-rewriting-todo.mlir (+1-1)
diff --git a/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
index 555601c5e92df6..36d6c5a4e6b3b2 100644
--- a/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
@@ -29,11 +29,7 @@ namespace {
 class GenericLoopConversionPattern
     : public mlir::OpConversionPattern<mlir::omp::LoopOp> {
 public:
-  enum class GenericLoopCombinedInfo {
-    Standalone,
-    TargetTeamsLoop,
-    TargetParallelLoop
-  };
+  enum class GenericLoopCombinedInfo { Standalone, TeamsLoop, ParallelLoop };
 
   using mlir::OpConversionPattern<mlir::omp::LoopOp>::OpConversionPattern;
 
@@ -53,12 +49,12 @@ class GenericLoopConversionPattern
 
     switch (combinedInfo) {
     case GenericLoopCombinedInfo::Standalone:
-      rewriteToSimdLoop(loopOp, rewriter);
+      rewriteStandaloneLoop(loopOp, rewriter);
       break;
-    case GenericLoopCombinedInfo::TargetParallelLoop:
+    case GenericLoopCombinedInfo::ParallelLoop:
       llvm_unreachable("not yet implemented: `parallel loop` direcitve");
       break;
-    case GenericLoopCombinedInfo::TargetTeamsLoop:
+    case GenericLoopCombinedInfo::TeamsLoop:
       rewriteToDistributeParallelDo(loopOp, rewriter);
       break;
     }
@@ -74,10 +70,10 @@ class GenericLoopConversionPattern
     switch (combinedInfo) {
     case GenericLoopCombinedInfo::Standalone:
       break;
-    case GenericLoopCombinedInfo::TargetParallelLoop:
+    case GenericLoopCombinedInfo::ParallelLoop:
       return loopOp.emitError(
-          "not yet implemented: Combined `omp target parallel loop` directive");
-    case GenericLoopCombinedInfo::TargetTeamsLoop:
+          "not yet implemented: Combined `parallel loop` directive");
+    case GenericLoopCombinedInfo::TeamsLoop:
       break;
     }
 
@@ -87,7 +83,10 @@ class GenericLoopConversionPattern
              << loopOp->getName() << " operation";
     };
 
-    if (loopOp.getBindKind())
+    // For standalone directives, `bind` is already supported. Other combined
+    // forms will be supported in a follow-up PR.
+    if (combinedInfo != GenericLoopCombinedInfo::Standalone &&
+        loopOp.getBindKind())
       return todo("bind");
 
     if (loopOp.getOrder())
@@ -96,7 +95,7 @@ class GenericLoopConversionPattern
     if (!loopOp.getReductionVars().empty())
       return todo("reduction");
 
-    // TODO For `target teams loop`, check similar constrains to what is checked
+    // TODO For `teams loop`, check similar constrains to what is checked
     // by `TeamsLoopChecker` in SemaOpenMP.cpp.
     return mlir::success();
   }
@@ -108,18 +107,36 @@ class GenericLoopConversionPattern
     GenericLoopCombinedInfo result = GenericLoopCombinedInfo::Standalone;
 
     if (auto teamsOp = mlir::dyn_cast_if_present<mlir::omp::TeamsOp>(parentOp))
-      if (mlir::isa_and_present<mlir::omp::TargetOp>(teamsOp->getParentOp()))
-        result = GenericLoopCombinedInfo::TargetTeamsLoop;
+      result = GenericLoopCombinedInfo::TeamsLoop;
 
     if (auto parallelOp =
             mlir::dyn_cast_if_present<mlir::omp::ParallelOp>(parentOp))
-      if (mlir::isa_and_present<mlir::omp::TargetOp>(parallelOp->getParentOp()))
-        result = GenericLoopCombinedInfo::TargetParallelLoop;
+      result = GenericLoopCombinedInfo::ParallelLoop;
 
     return result;
   }
 
-  /// Rewrites standalone `loop` directives to equivalent `simd` constructs.
+  void rewriteStandaloneLoop(mlir::omp::LoopOp loopOp,
+                             mlir::ConversionPatternRewriter &rewriter) const {
+    using namespace mlir::omp;
+    std::optional<ClauseBindKind> bindKind = loopOp.getBindKind();
+
+    if (!bindKind.has_value())
+      return rewriteToSimdLoop(loopOp, rewriter);
+
+    switch (*loopOp.getBindKind()) {
+    case ClauseBindKind::Parallel:
+      return rewriteToWsloop(loopOp, rewriter);
+    case ClauseBindKind::Teams:
+      return rewriteToDistrbute(loopOp, rewriter);
+    case ClauseBindKind::Thread:
+      return rewriteToSimdLoop(loopOp, rewriter);
+    }
+  }
+
+  /// Rewrites standalone `loop` (without `bind` clause or with
+  /// `bind(parallel)`) directives to equivalent `simd` constructs.
+  ///
   /// The reasoning behind this decision is that according to the spec (version
   /// 5.2, section 11.7.1):
   ///
@@ -147,30 +164,48 @@ class GenericLoopConversionPattern
   /// the directive.
   void rewriteToSimdLoop(mlir::omp::LoopOp loopOp,
                          mlir::ConversionPatternRewriter &rewriter) const {
-    loopOp.emitWarning("Detected standalone OpenMP `loop` directive, the "
-                       "associated loop will be rewritten to `simd`.");
-    mlir::omp::SimdOperands simdClauseOps;
-    simdClauseOps.privateVars = loopOp.getPrivateVars();
+    loopOp.emitWarning(
+        "Detected standalone OpenMP `loop` directive with thread binding, "
+        "the associated loop will be rewritten to `simd`.");
+    rewriteToSingleWrapperOp<mlir::omp::SimdOp, mlir::omp::SimdOperands>(
+        loopOp, rewriter);
+  }
+
+  void rewriteToDistrbute(mlir::omp::LoopOp loopOp,
+                          mlir::ConversionPatternRewriter &rewriter) const {
+    rewriteToSingleWrapperOp<mlir::omp::DistributeOp,
+                             mlir::omp::DistributeOperands>(loopOp, rewriter);
+  }
+
+  void rewriteToWsloop(mlir::omp::LoopOp loopOp,
+                       mlir::ConversionPatternRewriter &rewriter) const {
+    rewriteToSingleWrapperOp<mlir::omp::WsloopOp, mlir::omp::WsloopOperands>(
+        loopOp, rewriter);
+  }
+
+  template <typename OpTy, typename OpOperandsTy>
+  void
+  rewriteToSingleWrapperOp(mlir::omp::LoopOp loopOp,
+                           mlir::ConversionPatternRewriter &rewriter) const {
+    OpOperandsTy clauseOps;
+    clauseOps.privateVars = loopOp.getPrivateVars();
 
     auto privateSyms = loopOp.getPrivateSyms();
     if (privateSyms)
-      simdClauseOps.privateSyms.assign(privateSyms->begin(),
-                                       privateSyms->end());
+      clauseOps.privateSyms.assign(privateSyms->begin(), privateSyms->end());
 
-    Fortran::common::openmp::EntryBlockArgs simdArgs;
-    simdArgs.priv.vars = simdClauseOps.privateVars;
+    Fortran::common::openmp::EntryBlockArgs args;
+    args.priv.vars = clauseOps.privateVars;
 
-    auto simdOp =
-        rewriter.create<mlir::omp::SimdOp>(loopOp.getLoc(), simdClauseOps);
-    mlir::Block *simdBlock =
-        genEntryBlock(rewriter, simdArgs, simdOp.getRegion());
+    auto wrapperOp = rewriter.create<OpTy>(loopOp.getLoc(), clauseOps);
+    mlir::Block *opBlock = genEntryBlock(rewriter, args, wrapperOp.getRegion());
 
     mlir::IRMapping mapper;
     mlir::Block &loopBlock = *loopOp.getRegion().begin();
 
-    for (auto [loopOpArg, simdopArg] :
-         llvm::zip_equal(loopBlock.getArguments(), simdBlock->getArguments()))
-      mapper.map(loopOpArg, simdopArg);
+    for (auto [loopOpArg, opArg] :
+         llvm::zip_equal(loopBlock.getArguments(), opBlock->getArguments()))
+      mapper.map(loopOpArg, opArg);
 
     rewriter.clone(*loopOp.begin(), mapper);
   }
diff --git a/flang/test/Lower/OpenMP/loop-directive.f90 b/flang/test/Lower/OpenMP/loop-directive.f90
index 9fa0de3bfe171a..845905da0fcba2 100644
--- a/flang/test/Lower/OpenMP/loop-directive.f90
+++ b/flang/test/Lower/OpenMP/loop-directive.f90
@@ -92,7 +92,7 @@ subroutine test_reduction()
 ! CHECK-LABEL: func.func @_QPtest_bind
 subroutine test_bind()
   integer :: i, dummy = 1
-  ! CHECK: omp.loop bind(thread) private(@{{.*}} %{{.*}}#0 -> %{{.*}} : {{.*}}) {
+  ! CHECK: omp.simd private(@{{.*}} %{{.*}}#0 -> %{{.*}} : {{.*}}) {
   ! CHECK: }
   !$omp loop bind(thread)
   do i=1,10
@@ -139,3 +139,43 @@ subroutine test_nested_directives
   end do
   !$omp end target teams
 end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_standalone_bind_teams
+subroutine test_standalone_bind_teams
+  implicit none
+  integer, parameter :: N = 100000
+  integer a(N), b(N), c(N)
+  integer j,i, num, flag;
+  num = N
+
+  ! CHECK:     omp.distribute
+  ! CHECK-SAME:  private(@{{.*}}Ea_private_ref_100000xi32 {{[^,]*}},
+  ! CHECK-SAME:          @{{.*}}Ei_private_ref_i32 {{.*}} : {{.*}}) {
+  ! CHECK:       omp.loop_nest {{.*}} {
+  ! CHECK:       }
+  ! CHECK:     }
+  !$omp loop bind(teams) private(a)
+  do i=1,N
+    c(i) = a(i) * b(i)
+  end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_standalone_bind_parallel
+subroutine test_standalone_bind_parallel
+  implicit none
+  integer, parameter :: N = 100000
+  integer a(N), b(N), c(N)
+  integer j,i, num, flag;
+  num = N
+
+  ! CHECK:     omp.wsloop
+  ! CHECK-SAME:  private(@{{.*}}Ea_private_ref_100000xi32 {{[^,]*}},
+  ! CHECK-SAME:          @{{.*}}Ei_private_ref_i32 {{.*}} : {{.*}}) {
+  ! CHECK:       omp.loop_nest {{.*}} {
+  ! CHECK:       }
+  ! CHECK:     }
+  !$omp loop bind(parallel) private(a)
+  do i=1,N
+    c(i) = a(i) * b(i)
+  end do
+end subroutine
diff --git a/flang/test/Transforms/generic-loop-rewriting-todo.mlir b/flang/test/Transforms/generic-loop-rewriting-todo.mlir
index becd6b8dcb5cb4..3259ceca70d50d 100644
--- a/flang/test/Transforms/generic-loop-rewriting-todo.mlir
+++ b/flang/test/Transforms/generic-loop-rewriting-todo.mlir
@@ -6,7 +6,7 @@ func.func @_QPtarget_parallel_loop() {
       %c0 = arith.constant 0 : i32
       %c10 = arith.constant 10 : i32
       %c1 = arith.constant 1 : i32
-      // expected-error@below {{not yet implemented: Combined `omp target parallel loop` directive}}
+      // expected-error@below {{not yet implemented: Combined `parallel loop` directive}}
       omp.loop {
         omp.loop_nest (%arg3) : i32 = (%c0) to (%c10) inclusive step (%c1) {
           omp.yield


@skatrak skatrak left a comment


Thank you, this LGTM.

Since the pass is unrelated to target, I think it would make sense to also update the "generic-loop-rewriting[-todo].mlir" tests and simplify them by removing the superfluous omp.target ops.

@ergawy ergawy force-pushed the remove_target_from_loop_rewrite branch 2 times, most recently from 7c4ec9b to 00cba1e Compare January 27, 2025 14:38

github-actions bot commented Jan 27, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

…rite

This removes mentions of `target` from the generic `loop` rewrite pass
since there is no need for it anyway. It is enough to detect the `loop`'s
nesting within `teams` or `parallel` directives.
@ergawy ergawy force-pushed the remove_target_from_loop_rewrite branch from 00cba1e to 2142211 Compare January 27, 2025 14:45
@ergawy ergawy merged commit 1e2d5f7 into llvm:main Jan 27, 2025
8 checks passed