Skip to content

[OpenMP][MLIR] Set omp.composite attr for composite loop wrappers and add verifier checks #102341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 12, 2024

Conversation

TIFitis
Copy link
Member

@TIFitis TIFitis commented Aug 7, 2024

This patch sets the omp.composite unit attr for composite wrapper ops and also add appropriate checks to the verifiers of supported ops for the presence/absence of the attribute.

This is patch 2/2 in a series of patches.

@llvmbot
Copy link
Member

llvmbot commented Aug 7, 2024

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-mlir

Author: Akash Banerjee (TIFitis)

Changes

This patch sets the omp.composite unit attr for composite wrapper ops and also add appropriate checks to the verifiers of supported ops for the presence/absence of the attribute.

This is patch 2/2 in a series of patches.


Full diff: https://github.com/llvm/llvm-project/pull/102341.diff

4 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+8)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+13-5)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+36)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+32)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index bbde77c14f36a1..3b18e7b3ecf80e 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2063,10 +2063,14 @@ static void genCompositeDistributeSimd(
   // TODO: Populate entry block arguments with private variables.
   auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
       converter, loc, distributeClauseOps, /*blockArgTypes=*/{});
+  llvm::cast<mlir::omp::ComposableOpInterface>(distributeOp.getOperation())
+      .setComposite(/*val=*/true);
 
   // TODO: Populate entry block arguments with reduction and private variables.
   auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps,
                                                 /*blockArgTypes=*/{});
+  llvm::cast<mlir::omp::ComposableOpInterface>(simdOp.getOperation())
+      .setComposite(/*val=*/true);
 
   // Construct wrapper entry block list and associated symbols. It is important
   // that the symbol order and the block argument order match, so that the
@@ -2111,10 +2115,14 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
   // TODO: Add private variables to entry block arguments.
   auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
       converter, loc, wsloopClauseOps, wsloopReductionTypes);
+  llvm::cast<mlir::omp::ComposableOpInterface>(wsloopOp.getOperation())
+      .setComposite(/*val=*/true);
 
   // TODO: Populate entry block arguments with reduction and private variables.
   auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps,
                                                 /*blockArgTypes=*/{});
+  llvm::cast<mlir::omp::ComposableOpInterface>(simdOp.getOperation())
+      .setComposite(/*val=*/true);
 
   // Construct wrapper entry block list and associated symbols. It is important
   // that the symbol and block argument order match, so that the symbol-value
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 68f92e6952694b..d63fdd88f79104 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -128,6 +128,7 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove, RecipeInterface]>
 
 def ParallelOp : OpenMP_Op<"parallel", traits = [
     AttrSizedOperandSegments, AutomaticAllocationScope,
+    DeclareOpInterfaceMethods<ComposableOpInterface>,
     DeclareOpInterfaceMethods<LoopWrapperInterface>,
     DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>,
     RecursiveMemoryEffects
@@ -356,7 +357,9 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
 //===----------------------------------------------------------------------===//
 
 def WsloopOp : OpenMP_Op<"wsloop", traits = [
-    AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
+    AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<ComposableOpInterface>,
+    DeclareOpInterfaceMethods<LoopWrapperInterface>,
     RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     OpenMP_AllocateClauseSkip<assemblyFormat = true>,
@@ -432,7 +435,9 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
 //===----------------------------------------------------------------------===//
 
 def SimdOp : OpenMP_Op<"simd", traits = [
-    AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
+    AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<ComposableOpInterface>,
+    DeclareOpInterfaceMethods<LoopWrapperInterface>,
     RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     OpenMP_AlignedClause, OpenMP_IfClause, OpenMP_LinearClause,
@@ -499,7 +504,9 @@ def YieldOp : OpenMP_Op<"yield",
 // Distribute construct [2.9.4.1]
 //===----------------------------------------------------------------------===//
 def DistributeOp : OpenMP_Op<"distribute", traits = [
-    AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
+    AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<ComposableOpInterface>,
+    DeclareOpInterfaceMethods<LoopWrapperInterface>,
     RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     OpenMP_AllocateClause, OpenMP_DistScheduleClause, OpenMP_OrderClause,
@@ -587,8 +594,9 @@ def TaskOp : OpenMP_Op<"task", traits = [
 
 def TaskloopOp : OpenMP_Op<"taskloop", traits = [
     AttrSizedOperandSegments, AutomaticAllocationScope,
-    DeclareOpInterfaceMethods<LoopWrapperInterface>, RecursiveMemoryEffects,
-    SingleBlock
+    DeclareOpInterfaceMethods<ComposableOpInterface>,
+    DeclareOpInterfaceMethods<LoopWrapperInterface>,
+    RecursiveMemoryEffects, SingleBlock
   ], clauses = [
     OpenMP_AllocateClause, OpenMP_FinalClause, OpenMP_GrainsizeClause,
     OpenMP_IfClause, OpenMP_InReductionClauseSkip<extraClassDeclaration = true>,
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 2d1de37239c82a..4cf847f55be215 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -142,6 +142,42 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
   ];
 }
 
+def ComposableOpInterface : OpInterface<"ComposableOpInterface"> {
+  let description = [{
+    OpenMP operations that can represent a single leaf of a composite OpenMP
+    construct.
+  }];
+
+  let cppNamespace = "::mlir::omp";
+
+  let methods = [
+    InterfaceMethod<
+      /*description=*/[{
+        Check whether the operation is representing a leaf of a composite OpenMP
+        construct.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isComposite",
+      (ins ), [{}], [{
+        return $_op->hasAttr("omp.composite");
+      }]
+    >,
+    InterfaceMethod<
+      /*description=*/[{
+        Mark the operation as part of an OpenMP composite construct.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"setComposite",
+      (ins "bool":$val), [{}], [{
+        if(val)
+          $_op->setDiscardableAttr("omp.composite", mlir::UnitAttr::get($_op->getContext()));
+        else
+          $_op->removeDiscardableAttr("omp.composite");
+      }]
+    >
+  ];
+}
+
 def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
   let description = [{
     OpenMP operations that support declare target have this interface.
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 11780f84697b15..641fbb5a1418f6 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1546,6 +1546,9 @@ LogicalResult ParallelOp::verify() {
     if (!isWrapper())
       return emitOpError() << "must take a loop wrapper role if nested inside "
                               "of 'omp.distribute'";
+    if (!llvm::cast<ComposableOpInterface>(getOperation()).isComposite())
+      return emitError()
+             << "'omp.composite' attribute missing from composite wrapper";
 
     if (LoopWrapperInterface nested = getNestedWrapper()) {
       // Check for the allowed leaf constructs that may appear in a composite
@@ -1555,6 +1558,9 @@ LogicalResult ParallelOp::verify() {
     } else {
       return emitOpError() << "must not wrap an 'omp.loop_nest' directly";
     }
+  } else if (llvm::cast<ComposableOpInterface>(getOperation()).isComposite()) {
+    return emitError()
+           << "'omp.composite' attribute present in non-composite wrapper";
   }
 
   if (getAllocateVars().size() != getAllocatorVars().size())
@@ -1749,10 +1755,18 @@ LogicalResult WsloopOp::verify() {
     return emitOpError() << "must be a loop wrapper";
 
   if (LoopWrapperInterface nested = getNestedWrapper()) {
+    if (!llvm::cast<ComposableOpInterface>(getOperation()).isComposite())
+      return emitError()
+             << "'omp.composite' attribute missing from composite wrapper";
+
     // Check for the allowed leaf constructs that may appear in a composite
     // construct directly after DO/FOR.
     if (!isa<SimdOp>(nested))
       return emitError() << "only supported nested wrapper is 'omp.simd'";
+
+  } else if (llvm::cast<ComposableOpInterface>(getOperation()).isComposite()) {
+    return emitError()
+           << "'omp.composite' attribute present in non-composite wrapper";
   }
 
   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
@@ -1796,6 +1810,10 @@ LogicalResult SimdOp::verify() {
   if (getNestedWrapper())
     return emitOpError() << "must wrap an 'omp.loop_nest' directly";
 
+  if (llvm::cast<ComposableOpInterface>(getOperation()).isComposite())
+    return emitError()
+           << "'omp.composite' attribute present in non-composite wrapper";
+
   return success();
 }
 
@@ -1825,11 +1843,17 @@ LogicalResult DistributeOp::verify() {
     return emitOpError() << "must be a loop wrapper";
 
   if (LoopWrapperInterface nested = getNestedWrapper()) {
+    if (!llvm::cast<ComposableOpInterface>(getOperation()).isComposite())
+      return emitError()
+             << "'omp.composite' attribute missing from composite wrapper";
     // Check for the allowed leaf constructs that may appear in a composite
     // construct directly after DISTRIBUTE.
     if (!isa<ParallelOp, SimdOp>(nested))
       return emitError() << "only supported nested wrappers are 'omp.parallel' "
                             "and 'omp.simd'";
+  } else if (llvm::cast<ComposableOpInterface>(getOperation()).isComposite()) {
+    return emitError()
+           << "'omp.composite' attribute present in non-composite wrapper";
   }
 
   return success();
@@ -2031,11 +2055,19 @@ LogicalResult TaskloopOp::verify() {
     return emitOpError() << "must be a loop wrapper";
 
   if (LoopWrapperInterface nested = getNestedWrapper()) {
+    if (!llvm::cast<ComposableOpInterface>(getOperation()).isComposite())
+      return emitError()
+             << "'omp.composite' attribute missing from composite wrapper";
+
     // Check for the allowed leaf constructs that may appear in a composite
     // construct directly after TASKLOOP.
     if (!isa<SimdOp>(nested))
       return emitError() << "only supported nested wrapper is 'omp.simd'";
+  } else if (llvm::cast<ComposableOpInterface>(getOperation()).isComposite()) {
+    return emitError()
+           << "'omp.composite' attribute present in non-composite wrapper";
   }
+
   return success();
 }
 

Copy link
Member

@skatrak skatrak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for picking up this work. In general it looks good, but apart from some nits I think some improvements can be made to the op verifiers as well.

Also, please update mlir/test/Dialect/OpenMP/ops.mlir to add the omp.composite attribute to all ops representing a composite construct, as well as adding test cases to mlir/test/Dialect/OpenMP/invalid.mlir to test each of the new error cases.

@TIFitis
Copy link
Member Author

TIFitis commented Aug 8, 2024

Thanks @skatrak for the review, I've updated the PR to address the comments.

Copy link
Member

@skatrak skatrak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you again Akash, this LGTM. I have a couple of minor nits left, but there's no need for another review by me after addressing them. Can you add to ops.mlir an instance of distribute parallel do/for simd after line 117?

TIFitis added a commit that referenced this pull request Aug 12, 2024
…te unitAttr. (#102340)

Adds a new ComposableOpInterface for OpenMP operations that can
represent a single leaf of a composite OpenMP construct.

This is patch 1/2 in a series of patches. Patch 2 - #102341.
Base automatically changed from users/akash/composable-op to main August 12, 2024 14:35
@TIFitis TIFitis merged commit f2f4193 into main Aug 12, 2024
3 of 4 checks passed
@TIFitis TIFitis deleted the users/akash/set-composite branch August 12, 2024 14:36
TIFitis added a commit that referenced this pull request Aug 12, 2024
Add small test that I missed adding to #102341.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category mlir:openmp mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants