Skip to content

[MLIR][OpenMP] Improve loop wrapper representation #97706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2024

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Jul 4, 2024

This patch replaces the SingleBlockImplicitTerminator<"TerminatorOp"> trait of loop wrapper operations for the SingleBlock trait. This enables a more robust implementation of the LoopWrapperInterface::isWrapper() method, since it does no longer have to deal with the potentially missing (implicit) terminator.

The LoopWrapperInterface::isWrapper() method is also extended to not identify as wrappers those operations which have a loop wrapper operation inside that is not taking a wrapper role. This is important for cases where omp.parallel is nested, which can but is not required to work as a loop wrapper.

Tests are updated to integrate these representation and validation changes.

@llvmbot
Copy link
Member

llvmbot commented Jul 4, 2024

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

Changes

This patch replaces the SingleBlockImplicitTerminator&lt;"TerminatorOp"&gt; trait of loop wrapper operations for the SingleBlock trait. This enables a more robust implementation of the LoopWrapperInterface::isWrapper() method, since it does no longer have to deal with the potentially missing (implicit) terminator.

The LoopWrapperInterface::isWrapper() method is also extended to not identify as wrappers those operations which have a loop wrapper operation inside that is not taking a wrapper role. This is important for cases where omp.parallel is nested, which can but is not required to work as a loop wrapper.

Tests are updated to integrate these representation and validation changes.


Full diff: https://github.com/llvm/llvm-project/pull/97706.diff

7 Files Affected:

  • (modified) flang/test/Fir/convert-to-llvm-openmp-and-fir.fir (+4)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+4-4)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+10-4)
  • (modified) mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir (+1)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+11-4)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+35)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+6)
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 8b62787bb3094..eca762d52a724 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -200,6 +200,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
         fir.store %3 to %6 : !fir.ref<i32>
         omp.yield
       }
+      omp.terminator
     }
     omp.terminator
   }
@@ -225,6 +226,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
 // CHECK:   llvm.store %[[I1]], %[[ARR_I_REF]] : i32, !llvm.ptr
 // CHECK: omp.yield
 // CHECK: }
+// CHECK: omp.terminator
 // CHECK: }
 // CHECK: omp.terminator
 // CHECK: }
@@ -518,6 +520,7 @@ func.func @_QPsimd_with_nested_loop() {
       fir.store %7 to %3 : !fir.ref<i32>
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -538,6 +541,7 @@ func.func @_QPsimd_with_nested_loop() {
 // CHECK:             ^bb3:
 // CHECK:               omp.yield
 // CHECK:             }
+// CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           llvm.return
 // CHECK:         }
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 99e14cd1b7b48..aed0d69619db2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -354,7 +354,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
 
 def WsloopOp : OpenMP_Op<"wsloop", traits = [
     AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
-    RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+    RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     // TODO: Complete clause list (allocate, private).
     // TODO: Sort clauses alphabetically.
@@ -418,7 +418,7 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
 
 def SimdOp : OpenMP_Op<"simd", traits = [
     AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
-    RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+    RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     // TODO: Complete clause list (linear, private, reduction).
     OpenMP_AlignedClause, OpenMP_IfClause, OpenMP_NontemporalClause,
@@ -485,7 +485,7 @@ def YieldOp : OpenMP_Op<"yield",
 //===----------------------------------------------------------------------===//
 def DistributeOp : OpenMP_Op<"distribute", traits = [
     AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
-    RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+    RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     // TODO: Complete clause list (private).
     // TODO: Sort clauses alphabetically.
@@ -575,7 +575,7 @@ def TaskOp : OpenMP_Op<"task", traits = [
 def TaskloopOp : OpenMP_Op<"taskloop", traits = [
     AttrSizedOperandSegments, AutomaticAllocationScope,
     DeclareOpInterfaceMethods<LoopWrapperInterface>, RecursiveMemoryEffects,
-    SingleBlockImplicitTerminator<"TerminatorOp">
+    SingleBlock
   ], clauses = [
     // TODO: Complete clause list (private).
     // TODO: Sort clauses alphabetically.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 31a306072d0ec..385aa8b1b016a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -84,8 +84,8 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
       /*description=*/[{
         Tell whether the operation could be taking the role of a loop wrapper.
         That is, it has a single region with a single block in which there are
-        two operations: another wrapper or `omp.loop_nest` operation and a
-        terminator.
+        two operations: another wrapper (also taking a loop wrapper role) or
+        `omp.loop_nest` operation and a terminator.
       }],
       /*retTy=*/"bool",
       /*methodName=*/"isWrapper",
@@ -102,8 +102,14 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
 
         Operation &firstOp = *r.op_begin();
         Operation &secondOp = *(std::next(r.op_begin()));
-        return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp) &&
-               secondOp.hasTrait<OpTrait::IsTerminator>();
+
+        if (!secondOp.hasTrait<OpTrait::IsTerminator>())
+          return false;
+
+        if (auto wrapper = ::llvm::dyn_cast<LoopWrapperInterface>(firstOp))
+          return wrapper.isWrapper();
+
+        return ::llvm::isa<LoopNestOp>(firstOp);
       }]
     >,
     InterfaceMethod<
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index 3aeb9e70522d5..4c9e09970279a 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -174,6 +174,7 @@ func.func @loop_nest_block_arg(%val : i32, %ub : i32, %i : index) {
     ^bb3:
       omp.yield
     }
+    omp.terminator
   }
   return
 }
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 2915963f704d3..91eeb0911160d 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -11,8 +11,8 @@ func.func @unknown_clause() {
 // -----
 
 func.func @not_wrapper() {
+  // expected-error@+1 {{op must be a loop wrapper}}
   omp.distribute {
-    // expected-error@+1 {{op must take a loop wrapper role if nested inside of 'omp.distribute'}}
     omp.parallel {
       %0 = arith.constant 0 : i32
       omp.terminator
@@ -383,12 +383,16 @@ func.func @omp_simd() -> () {
 
 // -----
 
-func.func @omp_simd_nested_wrapper() -> () {
+func.func @omp_simd_nested_wrapper(%lb : index, %ub : index, %step : index) -> () {
   // expected-error @below {{op must wrap an 'omp.loop_nest' directly}}
   omp.simd {
     omp.distribute {
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        omp.yield
+      }
       omp.terminator
     }
+    omp.terminator
   }
   return
 }
@@ -1960,6 +1964,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
       }
       omp.terminator
     }
+    omp.terminator
   }
   return
 }
@@ -2158,11 +2163,13 @@ func.func @omp_distribute_wrapper() -> () {
 
 // -----
 
-func.func @omp_distribute_nested_wrapper(%data_var : memref<i32>) -> () {
+func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) -> () {
   // expected-error @below {{only supported nested wrappers are 'omp.parallel' and 'omp.simd'}}
   omp.distribute {
     "omp.wsloop"() ({
-      %0 = arith.constant 0 : i32
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        "omp.yield"() : () -> ()
+      }
       "omp.terminator"() : () -> ()
     }) : () -> ()
     "omp.terminator"() : () -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index eb283840aa7ee..ff3b1e60f7cfe 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -617,6 +617,7 @@ func.func @omp_simd_pretty(%lb : index, %ub : index, %step : index) -> () {
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -632,6 +633,7 @@ func.func @omp_simd_pretty_aligned(%lb : index, %ub : index, %step : index,
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -643,6 +645,7 @@ func.func @omp_simd_pretty_if(%lb : index, %ub : index, %step : index, %if_cond
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -656,6 +659,7 @@ func.func @omp_simd_pretty_nontemporal(%lb : index, %ub : index, %step : index,
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -667,18 +671,21 @@ func.func @omp_simd_pretty_order(%lb : index, %ub : index, %step : index) -> ()
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.simd order(reproducible:concurrent)
   omp.simd order(reproducible:concurrent) {
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.simd order(unconstrained:concurrent)
   omp.simd order(unconstrained:concurrent) {
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -690,6 +697,7 @@ func.func @omp_simd_pretty_simdlen(%lb : index, %ub : index, %step : index) -> (
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -701,6 +709,7 @@ func.func @omp_simd_pretty_safelen(%lb : index, %ub : index, %step : index) -> (
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -720,42 +729,49 @@ func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>, %arg0 : i3
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute dist_schedule_static
   omp.distribute dist_schedule_static {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute dist_schedule_static chunk_size(%{{.+}} : i32)
   omp.distribute dist_schedule_static chunk_size(%chunk_size : i32) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute order(concurrent)
   omp.distribute order(concurrent) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute order(reproducible:concurrent)
   omp.distribute order(reproducible:concurrent) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute order(unconstrained:concurrent)
   omp.distribute order(unconstrained:concurrent) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute allocate(%{{.+}} : memref<i32> -> %{{.+}} : memref<i32>)
   omp.distribute allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute
   omp.distribute {
@@ -763,7 +779,9 @@ func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>, %arg0 : i3
       omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
         omp.yield
       }
+      omp.terminator
     }
+    omp.terminator
   }
   return
 }
@@ -2278,6 +2296,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testbool = "test.bool"() : () -> (i1)
@@ -2288,6 +2307,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop final(%{{[^)]+}}) {
@@ -2296,6 +2316,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop untied {
@@ -2304,6 +2325,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop mergeable {
@@ -2312,6 +2334,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testf32 = "test.f32"() : () -> (!llvm.ptr)
@@ -2322,6 +2345,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // Checking byref attribute for in_reduction
@@ -2331,6 +2355,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr) {
@@ -2339,6 +2364,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // check byref attrbute for reduction
@@ -2348,6 +2374,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop in_reduction(@add_f32 -> %{{.+}} : !llvm.ptr) reduction(@add_f32 -> %{{.+}} : !llvm.ptr) {
@@ -2356,6 +2383,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testi32 = "test.i32"() : () -> (i32)
@@ -2365,6 +2393,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testmemref = "test.memref"() : () -> (memref<i32>)
@@ -2374,6 +2403,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testi64 = "test.i64"() : () -> (i64)
@@ -2383,6 +2413,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop num_tasks(%{{[^:]+}}: i64) {
@@ -2391,6 +2422,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop nogroup {
@@ -2399,6 +2431,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop {
@@ -2408,7 +2441,9 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
         // CHECK: omp.yield
         omp.yield
       }
+      omp.terminator
     }
+    omp.terminator
   }
 
   // CHECK: return
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 321de67aa48a1..dfeaf4be33adb 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -726,6 +726,7 @@ llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -749,6 +750,7 @@ llvm.func @simd_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -769,6 +771,7 @@ llvm.func @simd_simple_multiple_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -788,6 +791,7 @@ llvm.func @simd_simple_multiple_simdlen_safelen(%lb1 : i64, %ub1 : i64, %step1 :
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -816,6 +820,7 @@ llvm.func @simd_if(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fi
       llvm.store %arg2, %1 : i32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -836,6 +841,7 @@ llvm.func @simd_order() {
       llvm.store %arg0, %2 : i64, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }

@llvmbot
Copy link
Member

llvmbot commented Jul 4, 2024

@llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch replaces the SingleBlockImplicitTerminator&lt;"TerminatorOp"&gt; trait of loop wrapper operations for the SingleBlock trait. This enables a more robust implementation of the LoopWrapperInterface::isWrapper() method, since it does no longer have to deal with the potentially missing (implicit) terminator.

The LoopWrapperInterface::isWrapper() method is also extended to not identify as wrappers those operations which have a loop wrapper operation inside that is not taking a wrapper role. This is important for cases where omp.parallel is nested, which can but is not required to work as a loop wrapper.

Tests are updated to integrate these representation and validation changes.


Full diff: https://github.com/llvm/llvm-project/pull/97706.diff

7 Files Affected:

  • (modified) flang/test/Fir/convert-to-llvm-openmp-and-fir.fir (+4)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+4-4)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+10-4)
  • (modified) mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir (+1)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+11-4)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+35)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+6)
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 8b62787bb3094..eca762d52a724 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -200,6 +200,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
         fir.store %3 to %6 : !fir.ref<i32>
         omp.yield
       }
+      omp.terminator
     }
     omp.terminator
   }
@@ -225,6 +226,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
 // CHECK:   llvm.store %[[I1]], %[[ARR_I_REF]] : i32, !llvm.ptr
 // CHECK: omp.yield
 // CHECK: }
+// CHECK: omp.terminator
 // CHECK: }
 // CHECK: omp.terminator
 // CHECK: }
@@ -518,6 +520,7 @@ func.func @_QPsimd_with_nested_loop() {
       fir.store %7 to %3 : !fir.ref<i32>
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -538,6 +541,7 @@ func.func @_QPsimd_with_nested_loop() {
 // CHECK:             ^bb3:
 // CHECK:               omp.yield
 // CHECK:             }
+// CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           llvm.return
 // CHECK:         }
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 99e14cd1b7b48..aed0d69619db2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -354,7 +354,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
 
 def WsloopOp : OpenMP_Op<"wsloop", traits = [
     AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
-    RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+    RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     // TODO: Complete clause list (allocate, private).
     // TODO: Sort clauses alphabetically.
@@ -418,7 +418,7 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
 
 def SimdOp : OpenMP_Op<"simd", traits = [
     AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
-    RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+    RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     // TODO: Complete clause list (linear, private, reduction).
     OpenMP_AlignedClause, OpenMP_IfClause, OpenMP_NontemporalClause,
@@ -485,7 +485,7 @@ def YieldOp : OpenMP_Op<"yield",
 //===----------------------------------------------------------------------===//
 def DistributeOp : OpenMP_Op<"distribute", traits = [
     AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
-    RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+    RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     // TODO: Complete clause list (private).
     // TODO: Sort clauses alphabetically.
@@ -575,7 +575,7 @@ def TaskOp : OpenMP_Op<"task", traits = [
 def TaskloopOp : OpenMP_Op<"taskloop", traits = [
     AttrSizedOperandSegments, AutomaticAllocationScope,
     DeclareOpInterfaceMethods<LoopWrapperInterface>, RecursiveMemoryEffects,
-    SingleBlockImplicitTerminator<"TerminatorOp">
+    SingleBlock
   ], clauses = [
     // TODO: Complete clause list (private).
     // TODO: Sort clauses alphabetically.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 31a306072d0ec..385aa8b1b016a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -84,8 +84,8 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
       /*description=*/[{
         Tell whether the operation could be taking the role of a loop wrapper.
         That is, it has a single region with a single block in which there are
-        two operations: another wrapper or `omp.loop_nest` operation and a
-        terminator.
+        two operations: another wrapper (also taking a loop wrapper role) or
+        `omp.loop_nest` operation and a terminator.
       }],
       /*retTy=*/"bool",
       /*methodName=*/"isWrapper",
@@ -102,8 +102,14 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
 
         Operation &firstOp = *r.op_begin();
         Operation &secondOp = *(std::next(r.op_begin()));
-        return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp) &&
-               secondOp.hasTrait<OpTrait::IsTerminator>();
+
+        if (!secondOp.hasTrait<OpTrait::IsTerminator>())
+          return false;
+
+        if (auto wrapper = ::llvm::dyn_cast<LoopWrapperInterface>(firstOp))
+          return wrapper.isWrapper();
+
+        return ::llvm::isa<LoopNestOp>(firstOp);
       }]
     >,
     InterfaceMethod<
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index 3aeb9e70522d5..4c9e09970279a 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -174,6 +174,7 @@ func.func @loop_nest_block_arg(%val : i32, %ub : i32, %i : index) {
     ^bb3:
       omp.yield
     }
+    omp.terminator
   }
   return
 }
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 2915963f704d3..91eeb0911160d 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -11,8 +11,8 @@ func.func @unknown_clause() {
 // -----
 
 func.func @not_wrapper() {
+  // expected-error@+1 {{op must be a loop wrapper}}
   omp.distribute {
-    // expected-error@+1 {{op must take a loop wrapper role if nested inside of 'omp.distribute'}}
     omp.parallel {
       %0 = arith.constant 0 : i32
       omp.terminator
@@ -383,12 +383,16 @@ func.func @omp_simd() -> () {
 
 // -----
 
-func.func @omp_simd_nested_wrapper() -> () {
+func.func @omp_simd_nested_wrapper(%lb : index, %ub : index, %step : index) -> () {
   // expected-error @below {{op must wrap an 'omp.loop_nest' directly}}
   omp.simd {
     omp.distribute {
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        omp.yield
+      }
       omp.terminator
     }
+    omp.terminator
   }
   return
 }
@@ -1960,6 +1964,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
       }
       omp.terminator
     }
+    omp.terminator
   }
   return
 }
@@ -2158,11 +2163,13 @@ func.func @omp_distribute_wrapper() -> () {
 
 // -----
 
-func.func @omp_distribute_nested_wrapper(%data_var : memref<i32>) -> () {
+func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) -> () {
   // expected-error @below {{only supported nested wrappers are 'omp.parallel' and 'omp.simd'}}
   omp.distribute {
     "omp.wsloop"() ({
-      %0 = arith.constant 0 : i32
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        "omp.yield"() : () -> ()
+      }
       "omp.terminator"() : () -> ()
     }) : () -> ()
     "omp.terminator"() : () -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index eb283840aa7ee..ff3b1e60f7cfe 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -617,6 +617,7 @@ func.func @omp_simd_pretty(%lb : index, %ub : index, %step : index) -> () {
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -632,6 +633,7 @@ func.func @omp_simd_pretty_aligned(%lb : index, %ub : index, %step : index,
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -643,6 +645,7 @@ func.func @omp_simd_pretty_if(%lb : index, %ub : index, %step : index, %if_cond
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -656,6 +659,7 @@ func.func @omp_simd_pretty_nontemporal(%lb : index, %ub : index, %step : index,
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -667,18 +671,21 @@ func.func @omp_simd_pretty_order(%lb : index, %ub : index, %step : index) -> ()
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.simd order(reproducible:concurrent)
   omp.simd order(reproducible:concurrent) {
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.simd order(unconstrained:concurrent)
   omp.simd order(unconstrained:concurrent) {
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -690,6 +697,7 @@ func.func @omp_simd_pretty_simdlen(%lb : index, %ub : index, %step : index) -> (
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -701,6 +709,7 @@ func.func @omp_simd_pretty_safelen(%lb : index, %ub : index, %step : index) -> (
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -720,42 +729,49 @@ func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>, %arg0 : i3
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute dist_schedule_static
   omp.distribute dist_schedule_static {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute dist_schedule_static chunk_size(%{{.+}} : i32)
   omp.distribute dist_schedule_static chunk_size(%chunk_size : i32) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute order(concurrent)
   omp.distribute order(concurrent) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute order(reproducible:concurrent)
   omp.distribute order(reproducible:concurrent) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute order(unconstrained:concurrent)
   omp.distribute order(unconstrained:concurrent) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute allocate(%{{.+}} : memref<i32> -> %{{.+}} : memref<i32>)
   omp.distribute allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute
   omp.distribute {
@@ -763,7 +779,9 @@ func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>, %arg0 : i3
       omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
         omp.yield
       }
+      omp.terminator
     }
+    omp.terminator
   }
   return
 }
@@ -2278,6 +2296,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testbool = "test.bool"() : () -> (i1)
@@ -2288,6 +2307,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop final(%{{[^)]+}}) {
@@ -2296,6 +2316,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop untied {
@@ -2304,6 +2325,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop mergeable {
@@ -2312,6 +2334,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testf32 = "test.f32"() : () -> (!llvm.ptr)
@@ -2322,6 +2345,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // Checking byref attribute for in_reduction
@@ -2331,6 +2355,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr) {
@@ -2339,6 +2364,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // check byref attrbute for reduction
@@ -2348,6 +2374,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop in_reduction(@add_f32 -> %{{.+}} : !llvm.ptr) reduction(@add_f32 -> %{{.+}} : !llvm.ptr) {
@@ -2356,6 +2383,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testi32 = "test.i32"() : () -> (i32)
@@ -2365,6 +2393,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testmemref = "test.memref"() : () -> (memref<i32>)
@@ -2374,6 +2403,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testi64 = "test.i64"() : () -> (i64)
@@ -2383,6 +2413,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop num_tasks(%{{[^:]+}}: i64) {
@@ -2391,6 +2422,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop nogroup {
@@ -2399,6 +2431,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop {
@@ -2408,7 +2441,9 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
         // CHECK: omp.yield
         omp.yield
       }
+      omp.terminator
     }
+    omp.terminator
   }
 
   // CHECK: return
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 321de67aa48a1..dfeaf4be33adb 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -726,6 +726,7 @@ llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -749,6 +750,7 @@ llvm.func @simd_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -769,6 +771,7 @@ llvm.func @simd_simple_multiple_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -788,6 +791,7 @@ llvm.func @simd_simple_multiple_simdlen_safelen(%lb1 : i64, %ub1 : i64, %step1 :
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -816,6 +820,7 @@ llvm.func @simd_if(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fi
       llvm.store %arg2, %1 : i32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -836,6 +841,7 @@ llvm.func @simd_order() {
       llvm.store %arg0, %2 : i64, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }

@llvmbot
Copy link
Member

llvmbot commented Jul 4, 2024

@llvm/pr-subscribers-flang-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch replaces the SingleBlockImplicitTerminator&lt;"TerminatorOp"&gt; trait of loop wrapper operations for the SingleBlock trait. This enables a more robust implementation of the LoopWrapperInterface::isWrapper() method, since it does no longer have to deal with the potentially missing (implicit) terminator.

The LoopWrapperInterface::isWrapper() method is also extended to not identify as wrappers those operations which have a loop wrapper operation inside that is not taking a wrapper role. This is important for cases where omp.parallel is nested, which can but is not required to work as a loop wrapper.

Tests are updated to integrate these representation and validation changes.


Full diff: https://github.com/llvm/llvm-project/pull/97706.diff

7 Files Affected:

  • (modified) flang/test/Fir/convert-to-llvm-openmp-and-fir.fir (+4)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+4-4)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+10-4)
  • (modified) mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir (+1)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+11-4)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+35)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+6)
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 8b62787bb3094..eca762d52a724 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -200,6 +200,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
         fir.store %3 to %6 : !fir.ref<i32>
         omp.yield
       }
+      omp.terminator
     }
     omp.terminator
   }
@@ -225,6 +226,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
 // CHECK:   llvm.store %[[I1]], %[[ARR_I_REF]] : i32, !llvm.ptr
 // CHECK: omp.yield
 // CHECK: }
+// CHECK: omp.terminator
 // CHECK: }
 // CHECK: omp.terminator
 // CHECK: }
@@ -518,6 +520,7 @@ func.func @_QPsimd_with_nested_loop() {
       fir.store %7 to %3 : !fir.ref<i32>
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -538,6 +541,7 @@ func.func @_QPsimd_with_nested_loop() {
 // CHECK:             ^bb3:
 // CHECK:               omp.yield
 // CHECK:             }
+// CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           llvm.return
 // CHECK:         }
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 99e14cd1b7b48..aed0d69619db2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -354,7 +354,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
 
 def WsloopOp : OpenMP_Op<"wsloop", traits = [
     AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
-    RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+    RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     // TODO: Complete clause list (allocate, private).
     // TODO: Sort clauses alphabetically.
@@ -418,7 +418,7 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
 
 def SimdOp : OpenMP_Op<"simd", traits = [
     AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
-    RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+    RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     // TODO: Complete clause list (linear, private, reduction).
     OpenMP_AlignedClause, OpenMP_IfClause, OpenMP_NontemporalClause,
@@ -485,7 +485,7 @@ def YieldOp : OpenMP_Op<"yield",
 //===----------------------------------------------------------------------===//
 def DistributeOp : OpenMP_Op<"distribute", traits = [
     AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
-    RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
+    RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     // TODO: Complete clause list (private).
     // TODO: Sort clauses alphabetically.
@@ -575,7 +575,7 @@ def TaskOp : OpenMP_Op<"task", traits = [
 def TaskloopOp : OpenMP_Op<"taskloop", traits = [
     AttrSizedOperandSegments, AutomaticAllocationScope,
     DeclareOpInterfaceMethods<LoopWrapperInterface>, RecursiveMemoryEffects,
-    SingleBlockImplicitTerminator<"TerminatorOp">
+    SingleBlock
   ], clauses = [
     // TODO: Complete clause list (private).
     // TODO: Sort clauses alphabetically.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 31a306072d0ec..385aa8b1b016a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -84,8 +84,8 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
       /*description=*/[{
         Tell whether the operation could be taking the role of a loop wrapper.
         That is, it has a single region with a single block in which there are
-        two operations: another wrapper or `omp.loop_nest` operation and a
-        terminator.
+        two operations: another wrapper (also taking a loop wrapper role) or
+        `omp.loop_nest` operation and a terminator.
       }],
       /*retTy=*/"bool",
       /*methodName=*/"isWrapper",
@@ -102,8 +102,14 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
 
         Operation &firstOp = *r.op_begin();
         Operation &secondOp = *(std::next(r.op_begin()));
-        return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp) &&
-               secondOp.hasTrait<OpTrait::IsTerminator>();
+
+        if (!secondOp.hasTrait<OpTrait::IsTerminator>())
+          return false;
+
+        if (auto wrapper = ::llvm::dyn_cast<LoopWrapperInterface>(firstOp))
+          return wrapper.isWrapper();
+
+        return ::llvm::isa<LoopNestOp>(firstOp);
       }]
     >,
     InterfaceMethod<
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index 3aeb9e70522d5..4c9e09970279a 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -174,6 +174,7 @@ func.func @loop_nest_block_arg(%val : i32, %ub : i32, %i : index) {
     ^bb3:
       omp.yield
     }
+    omp.terminator
   }
   return
 }
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 2915963f704d3..91eeb0911160d 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -11,8 +11,8 @@ func.func @unknown_clause() {
 // -----
 
 func.func @not_wrapper() {
+  // expected-error@+1 {{op must be a loop wrapper}}
   omp.distribute {
-    // expected-error@+1 {{op must take a loop wrapper role if nested inside of 'omp.distribute'}}
     omp.parallel {
       %0 = arith.constant 0 : i32
       omp.terminator
@@ -383,12 +383,16 @@ func.func @omp_simd() -> () {
 
 // -----
 
-func.func @omp_simd_nested_wrapper() -> () {
+func.func @omp_simd_nested_wrapper(%lb : index, %ub : index, %step : index) -> () {
   // expected-error @below {{op must wrap an 'omp.loop_nest' directly}}
   omp.simd {
     omp.distribute {
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        omp.yield
+      }
       omp.terminator
     }
+    omp.terminator
   }
   return
 }
@@ -1960,6 +1964,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
       }
       omp.terminator
     }
+    omp.terminator
   }
   return
 }
@@ -2158,11 +2163,13 @@ func.func @omp_distribute_wrapper() -> () {
 
 // -----
 
-func.func @omp_distribute_nested_wrapper(%data_var : memref<i32>) -> () {
+func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) -> () {
   // expected-error @below {{only supported nested wrappers are 'omp.parallel' and 'omp.simd'}}
   omp.distribute {
     "omp.wsloop"() ({
-      %0 = arith.constant 0 : i32
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        "omp.yield"() : () -> ()
+      }
       "omp.terminator"() : () -> ()
     }) : () -> ()
     "omp.terminator"() : () -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index eb283840aa7ee..ff3b1e60f7cfe 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -617,6 +617,7 @@ func.func @omp_simd_pretty(%lb : index, %ub : index, %step : index) -> () {
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -632,6 +633,7 @@ func.func @omp_simd_pretty_aligned(%lb : index, %ub : index, %step : index,
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -643,6 +645,7 @@ func.func @omp_simd_pretty_if(%lb : index, %ub : index, %step : index, %if_cond
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -656,6 +659,7 @@ func.func @omp_simd_pretty_nontemporal(%lb : index, %ub : index, %step : index,
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -667,18 +671,21 @@ func.func @omp_simd_pretty_order(%lb : index, %ub : index, %step : index) -> ()
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.simd order(reproducible:concurrent)
   omp.simd order(reproducible:concurrent) {
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.simd order(unconstrained:concurrent)
   omp.simd order(unconstrained:concurrent) {
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -690,6 +697,7 @@ func.func @omp_simd_pretty_simdlen(%lb : index, %ub : index, %step : index) -> (
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -701,6 +709,7 @@ func.func @omp_simd_pretty_safelen(%lb : index, %ub : index, %step : index) -> (
     omp.loop_nest (%iv): index = (%lb) to (%ub) step (%step) {
       omp.yield
     }
+    omp.terminator
   }
   return
 }
@@ -720,42 +729,49 @@ func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>, %arg0 : i3
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute dist_schedule_static
   omp.distribute dist_schedule_static {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute dist_schedule_static chunk_size(%{{.+}} : i32)
   omp.distribute dist_schedule_static chunk_size(%chunk_size : i32) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute order(concurrent)
   omp.distribute order(concurrent) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute order(reproducible:concurrent)
   omp.distribute order(reproducible:concurrent) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute order(unconstrained:concurrent)
   omp.distribute order(unconstrained:concurrent) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute allocate(%{{.+}} : memref<i32> -> %{{.+}} : memref<i32>)
   omp.distribute allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
     omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
       omp.yield
     }
+    omp.terminator
   }
   // CHECK: omp.distribute
   omp.distribute {
@@ -763,7 +779,9 @@ func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>, %arg0 : i3
       omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
         omp.yield
       }
+      omp.terminator
     }
+    omp.terminator
   }
   return
 }
@@ -2278,6 +2296,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testbool = "test.bool"() : () -> (i1)
@@ -2288,6 +2307,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop final(%{{[^)]+}}) {
@@ -2296,6 +2316,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop untied {
@@ -2304,6 +2325,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop mergeable {
@@ -2312,6 +2334,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testf32 = "test.f32"() : () -> (!llvm.ptr)
@@ -2322,6 +2345,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // Checking byref attribute for in_reduction
@@ -2331,6 +2355,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr, @add_f32 -> %{{.+}} : !llvm.ptr) {
@@ -2339,6 +2364,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // check byref attrbute for reduction
@@ -2348,6 +2374,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop in_reduction(@add_f32 -> %{{.+}} : !llvm.ptr) reduction(@add_f32 -> %{{.+}} : !llvm.ptr) {
@@ -2356,6 +2383,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testi32 = "test.i32"() : () -> (i32)
@@ -2365,6 +2393,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testmemref = "test.memref"() : () -> (memref<i32>)
@@ -2374,6 +2403,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   %testi64 = "test.i64"() : () -> (i64)
@@ -2383,6 +2413,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop num_tasks(%{{[^:]+}}: i64) {
@@ -2391,6 +2422,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop nogroup {
@@ -2399,6 +2431,7 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
       // CHECK: omp.yield
       omp.yield
     }
+    omp.terminator
   }
 
   // CHECK: omp.taskloop {
@@ -2408,7 +2441,9 @@ func.func @omp_taskloop(%lb: i32, %ub: i32, %step: i32) -> () {
         // CHECK: omp.yield
         omp.yield
       }
+      omp.terminator
     }
+    omp.terminator
   }
 
   // CHECK: return
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 321de67aa48a1..dfeaf4be33adb 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -726,6 +726,7 @@ llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -749,6 +750,7 @@ llvm.func @simd_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -769,6 +771,7 @@ llvm.func @simd_simple_multiple_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -788,6 +791,7 @@ llvm.func @simd_simple_multiple_simdlen_safelen(%lb1 : i64, %ub1 : i64, %step1 :
       llvm.store %3, %5 : f32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -816,6 +820,7 @@ llvm.func @simd_if(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fi
       llvm.store %arg2, %1 : i32, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }
@@ -836,6 +841,7 @@ llvm.func @simd_order() {
       llvm.store %arg0, %2 : i64, !llvm.ptr
       omp.yield
     }
+    omp.terminator
   }
   llvm.return
 }

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

This patch replaces the `SingleBlockImplicitTerminator<"TerminatorOp">` trait
of loop wrapper operations for the `SingleBlock` trait. This enables a more
robust implementation of the `LoopWrapperInterface::isWrapper()` method, since
it does no longer have to deal with the potentially missing (implicit)
terminator.

The `LoopWrapperInterface::isWrapper()` method is also extended to not identify
as wrappers those operations which have a loop wrapper operation inside that is
not taking a wrapper role. This is important for cases where `omp.parallel`
is nested, which can but is not required to work as a loop wrapper.

Tests are updated to integrate these representation and validation changes.
@skatrak skatrak force-pushed the fix-loopwrapperiface branch from 0e48712 to f666bb9 Compare July 8, 2024 09:46
@skatrak skatrak merged commit d6fb899 into llvm:main Jul 8, 2024
5 of 6 checks passed
@skatrak skatrak deleted the fix-loopwrapperiface branch July 8, 2024 10:21
@llvm-ci
Copy link
Collaborator

llvm-ci commented Jul 8, 2024

LLVM Buildbot has detected a new failure on builder mlir-nvidia-gcc7 running on mlir-nvidia while building flang,mlir at step 5 "build-check-mlir-build-only".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/982

Here is the relevant piece of the build log for the reference:

Step 5 (build-check-mlir-build-only) failure: build (failure)
...
54.940 [81/16/4381] Linking CXX static library lib/libMLIRTestIR.a
54.963 [80/16/4382] Linking CXX static library lib/libMLIRTestPass.a
54.977 [79/16/4383] Linking CXX static library lib/libMLIRTestPDLL.a
55.000 [78/16/4384] Linking CXX static library lib/libMLIRTestTransforms.a
55.017 [77/16/4385] Linking CXX static library lib/libMyExtensionCh2.a
55.037 [76/16/4386] Building MyExtensionTypes.h.inc...
55.056 [75/16/4387] Building MyExtensionTypes.cpp.inc...
55.079 [74/16/4388] Building MyExtension.h.inc...
55.102 [73/16/4389] Building MyExtension.cpp.inc...
55.128 [72/16/4390] Building CXX object tools/mlir/examples/minimal-opt/CMakeFiles/mlir-minimal-opt.dir/mlir-minimal-opt.cpp.o
command timed out: 1200 seconds without output running [b'ninja', b'-j', b'16', b'check-mlir-build-only'], attempting to kill
process killed by signal 9
program finished with exit code -1
elapsedTime=2352.164172

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:openmp flang Flang issues not falling into any other category mlir:llvm mlir:openmp mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants