-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][mesh] Make sharding propagation and spmdization work on FuncOpInterface #84415
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Boian Petkantchin (sogartar) ChangesMake them more general instead of only supporting Full diff: https://github.com/llvm/llvm-project/pull/84415.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 7fb6631574b410..06ebf151e7d649 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td"
// ShardingPropagation
//===----------------------------------------------------------------------===//
-def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionOpInterface"> {
let summary = "sharding propagation";
let description = [{
Propagates sharding information throughout the graph. After this pass, each
@@ -29,7 +29,7 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
];
}
-def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
+def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> {
let summary = "Partition a function into SPMD form.";
let description = [{
This pass fits in right after a pass that annotates the function with
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 9f2647b21cbfc8..29320f1e339f86 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include <vector>
@@ -172,9 +173,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
struct ShardingPropagation
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
void runOnOperation() override {
- func::FuncOp funcOp = getOperation();
+ FunctionOpInterface funcOp = getOperation();
MLIRContext *ctx = funcOp.getContext();
- Region ®ion = funcOp.getBody();
+ Region ®ion = funcOp.getFunctionBody();
OpBuilder builder(ctx);
if (!region.hasOneBlock()) {
funcOp.emitOpError() << "only one block is supported!";
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index c4d8b0b15e462c..e4868435135ed1 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -24,6 +24,8 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
@@ -694,7 +696,7 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
}
static LogicalResult
-spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
+spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection) {
OpBuilder builder(op.getFunctionBody());
@@ -717,21 +719,21 @@ spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
// Find a return op and change the function results signature to its operands
// signature.
- func::ReturnOp returnOp;
- for (Block &block : op.getBody()) {
+ Operation *returnOp = nullptr;
+ for (Block &block : op.getFunctionBody()) {
if (block.empty()) {
continue;
}
- returnOp = llvm::cast<func::ReturnOp>(block.back());
- if (returnOp) {
+ if (block.back().hasTrait<OpTrait::ReturnLike>()) {
+ returnOp = &block.back();
break;
}
}
assert(returnOp);
- op.setFunctionType(FunctionType::get(op->getContext(),
- op.getBody().front().getArgumentTypes(),
- returnOp->getOperandTypes()));
+ op.setType(FunctionType::get(op->getContext(),
+ op.getFunctionBody().front().getArgumentTypes(),
+ returnOp->getOperandTypes()));
return success();
}
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
index 6d21def8de2753..bd56c801283b17 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -1,6 +1,5 @@
// RUN: mlir-opt \
-// RUN: --mesh-spmdization \
-// RUN: --test-constant-fold \
+// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
// RUN: --split-input-file \
// RUN: %s | FileCheck %s
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 94f8d94073c5ef..270787ab518831 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
mesh.mesh @mesh_1d(shape = ?)
mesh.mesh @mesh_2d(shape = 2x4)
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 572d3eb55eaaae..2df247aba35155 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s
+// RUN: mlir-opt \
+// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
+// RUN: %s | FileCheck %s
mesh.mesh @mesh_1d(shape = 2)
|
@llvm/pr-subscribers-mlir-linalg Author: Boian Petkantchin (sogartar) ChangesMake them more general instead of only supporting Full diff: https://github.com/llvm/llvm-project/pull/84415.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 7fb6631574b410..06ebf151e7d649 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td"
// ShardingPropagation
//===----------------------------------------------------------------------===//
-def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionOpInterface"> {
let summary = "sharding propagation";
let description = [{
Propagates sharding information throughout the graph. After this pass, each
@@ -29,7 +29,7 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
];
}
-def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
+def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> {
let summary = "Partition a function into SPMD form.";
let description = [{
This pass fits in right after a pass that annotates the function with
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 9f2647b21cbfc8..29320f1e339f86 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include <vector>
@@ -172,9 +173,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
struct ShardingPropagation
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
void runOnOperation() override {
- func::FuncOp funcOp = getOperation();
+ FunctionOpInterface funcOp = getOperation();
MLIRContext *ctx = funcOp.getContext();
- Region ®ion = funcOp.getBody();
+ Region ®ion = funcOp.getFunctionBody();
OpBuilder builder(ctx);
if (!region.hasOneBlock()) {
funcOp.emitOpError() << "only one block is supported!";
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index c4d8b0b15e462c..e4868435135ed1 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -24,6 +24,8 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
@@ -694,7 +696,7 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
}
static LogicalResult
-spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
+spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection) {
OpBuilder builder(op.getFunctionBody());
@@ -717,21 +719,21 @@ spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
// Find a return op and change the function results signature to its operands
// signature.
- func::ReturnOp returnOp;
- for (Block &block : op.getBody()) {
+ Operation *returnOp = nullptr;
+ for (Block &block : op.getFunctionBody()) {
if (block.empty()) {
continue;
}
- returnOp = llvm::cast<func::ReturnOp>(block.back());
- if (returnOp) {
+ if (block.back().hasTrait<OpTrait::ReturnLike>()) {
+ returnOp = &block.back();
break;
}
}
assert(returnOp);
- op.setFunctionType(FunctionType::get(op->getContext(),
- op.getBody().front().getArgumentTypes(),
- returnOp->getOperandTypes()));
+ op.setType(FunctionType::get(op->getContext(),
+ op.getFunctionBody().front().getArgumentTypes(),
+ returnOp->getOperandTypes()));
return success();
}
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
index 6d21def8de2753..bd56c801283b17 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -1,6 +1,5 @@
// RUN: mlir-opt \
-// RUN: --mesh-spmdization \
-// RUN: --test-constant-fold \
+// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
// RUN: --split-input-file \
// RUN: %s | FileCheck %s
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 94f8d94073c5ef..270787ab518831 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
mesh.mesh @mesh_1d(shape = ?)
mesh.mesh @mesh_2d(shape = 2x4)
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 572d3eb55eaaae..2df247aba35155 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s
+// RUN: mlir-opt \
+// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
+// RUN: %s | FileCheck %s
mesh.mesh @mesh_1d(shape = 2)
|
@yaochengji, could review this PR? |
Adding missing dependencies for llvm#84415.
Make them more general instead of only supporting
func::FuncOp
.