Skip to content

[mlir][mesh] Make sharding propagation and spmdization work on FuncOpInterface #84415

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 8, 2024

Conversation

sogartar
Copy link
Contributor

@sogartar sogartar commented Mar 8, 2024

Make them more general instead of only supporting func::FuncOp.

@llvmbot
Copy link
Member

llvmbot commented Mar 8, 2024

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

Changes

Make them more general instead of only supporting func::FuncOp.


Full diff: https://github.com/llvm/llvm-project/pull/84415.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (+2-2)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+10-8)
  • (modified) mlir/test/Dialect/Linalg/mesh-spmdization.mlir (+1-2)
  • (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+1-1)
  • (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+3-1)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 7fb6631574b410..06ebf151e7d649 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td"
 // ShardingPropagation
 //===----------------------------------------------------------------------===//
 
-def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionOpInterface"> {
   let summary = "sharding propagation";
   let description = [{
     Propagates sharding information throughout the graph. After this pass, each
@@ -29,7 +29,7 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
   ];
 }
 
-def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
+def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> {
   let summary = "Partition a function into SPMD form.";
   let description = [{
     This pass fits in right after a pass that annotates the function with
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 9f2647b21cbfc8..29320f1e339f86 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/Support/Debug.h"
 #include <vector>
@@ -172,9 +173,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 struct ShardingPropagation
     : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
   void runOnOperation() override {
-    func::FuncOp funcOp = getOperation();
+    FunctionOpInterface funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
-    Region &region = funcOp.getBody();
+    Region &region = funcOp.getFunctionBody();
     OpBuilder builder(ctx);
     if (!region.hasOneBlock()) {
       funcOp.emitOpError() << "only one block is supported!";
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index c4d8b0b15e462c..e4868435135ed1 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -24,6 +24,8 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
@@ -694,7 +696,7 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
 }
 
 static LogicalResult
-spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
+spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
               SymbolTableCollection &symbolTableCollection) {
   OpBuilder builder(op.getFunctionBody());
 
@@ -717,21 +719,21 @@ spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
 
   // Find a return op and change the function results signature to its operands
   // signature.
-  func::ReturnOp returnOp;
-  for (Block &block : op.getBody()) {
+  Operation *returnOp = nullptr;
+  for (Block &block : op.getFunctionBody()) {
     if (block.empty()) {
       continue;
     }
 
-    returnOp = llvm::cast<func::ReturnOp>(block.back());
-    if (returnOp) {
+    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
+      returnOp = &block.back();
       break;
     }
   }
   assert(returnOp);
-  op.setFunctionType(FunctionType::get(op->getContext(),
-                                       op.getBody().front().getArgumentTypes(),
-                                       returnOp->getOperandTypes()));
+  op.setType(FunctionType::get(op->getContext(),
+                               op.getFunctionBody().front().getArgumentTypes(),
+                               returnOp->getOperandTypes()));
 
   return success();
 }
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
index 6d21def8de2753..bd56c801283b17 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -1,6 +1,5 @@
 // RUN: mlir-opt \
-// RUN:  --mesh-spmdization \
-// RUN:  --test-constant-fold \
+// RUN:  --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
 // RUN:  --split-input-file \
 // RUN:  %s | FileCheck %s
 
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 94f8d94073c5ef..270787ab518831 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
 
 mesh.mesh @mesh_1d(shape = ?)
 mesh.mesh @mesh_2d(shape = 2x4)
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 572d3eb55eaaae..2df247aba35155 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s
+// RUN: mlir-opt \
+// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
+// RUN:   %s | FileCheck %s
 
 mesh.mesh @mesh_1d(shape = 2)
 

@llvmbot
Copy link
Member

llvmbot commented Mar 8, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Boian Petkantchin (sogartar)

Changes

Make them more general instead of only supporting func::FuncOp.


Full diff: https://github.com/llvm/llvm-project/pull/84415.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (+2-2)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+10-8)
  • (modified) mlir/test/Dialect/Linalg/mesh-spmdization.mlir (+1-2)
  • (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+1-1)
  • (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+3-1)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 7fb6631574b410..06ebf151e7d649 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td"
 // ShardingPropagation
 //===----------------------------------------------------------------------===//
 
-def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionOpInterface"> {
   let summary = "sharding propagation";
   let description = [{
     Propagates sharding information throughout the graph. After this pass, each
@@ -29,7 +29,7 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
   ];
 }
 
-def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
+def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> {
   let summary = "Partition a function into SPMD form.";
   let description = [{
     This pass fits in right after a pass that annotates the function with
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 9f2647b21cbfc8..29320f1e339f86 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/Support/Debug.h"
 #include <vector>
@@ -172,9 +173,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 struct ShardingPropagation
     : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
   void runOnOperation() override {
-    func::FuncOp funcOp = getOperation();
+    FunctionOpInterface funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
-    Region &region = funcOp.getBody();
+    Region &region = funcOp.getFunctionBody();
     OpBuilder builder(ctx);
     if (!region.hasOneBlock()) {
       funcOp.emitOpError() << "only one block is supported!";
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index c4d8b0b15e462c..e4868435135ed1 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -24,6 +24,8 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
@@ -694,7 +696,7 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
 }
 
 static LogicalResult
-spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
+spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
               SymbolTableCollection &symbolTableCollection) {
   OpBuilder builder(op.getFunctionBody());
 
@@ -717,21 +719,21 @@ spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
 
   // Find a return op and change the function results signature to its operands
   // signature.
-  func::ReturnOp returnOp;
-  for (Block &block : op.getBody()) {
+  Operation *returnOp = nullptr;
+  for (Block &block : op.getFunctionBody()) {
     if (block.empty()) {
       continue;
     }
 
-    returnOp = llvm::cast<func::ReturnOp>(block.back());
-    if (returnOp) {
+    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
+      returnOp = &block.back();
       break;
     }
   }
   assert(returnOp);
-  op.setFunctionType(FunctionType::get(op->getContext(),
-                                       op.getBody().front().getArgumentTypes(),
-                                       returnOp->getOperandTypes()));
+  op.setType(FunctionType::get(op->getContext(),
+                               op.getFunctionBody().front().getArgumentTypes(),
+                               returnOp->getOperandTypes()));
 
   return success();
 }
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
index 6d21def8de2753..bd56c801283b17 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -1,6 +1,5 @@
 // RUN: mlir-opt \
-// RUN:  --mesh-spmdization \
-// RUN:  --test-constant-fold \
+// RUN:  --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
 // RUN:  --split-input-file \
 // RUN:  %s | FileCheck %s
 
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 94f8d94073c5ef..270787ab518831 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
 
 mesh.mesh @mesh_1d(shape = ?)
 mesh.mesh @mesh_2d(shape = 2x4)
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 572d3eb55eaaae..2df247aba35155 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s
+// RUN: mlir-opt \
+// RUN:   --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
+// RUN:   %s | FileCheck %s
 
 mesh.mesh @mesh_1d(shape = 2)
 

@sogartar
Copy link
Contributor Author

sogartar commented Mar 8, 2024

@yaochengji, could review this PR?

@sogartar sogartar merged commit abfac56 into llvm:main Mar 8, 2024
mmilanifard added a commit to mmilanifard/llvm-project that referenced this pull request Mar 8, 2024
Adding missing dependencies for llvm#84415.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants