-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Revert "[mlir][mesh] adding option for traversal order in sharding propagation" #145531
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
…opagatio…" This reverts commit 43e1a5a.
@llvm/pr-subscribers-mlir Author: Qinkun Bao (qinkunbao) ChangesReverts llvm/llvm-project#144079 Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140 Full diff: https://github.com/llvm/llvm-project/pull/145531.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index c4d512b60bc51..1dc178586e918 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -206,6 +206,9 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
// Use newShardOp if it is not null. Otherwise create a new one.
// May insert resharding if required.
// Potentially updates newShardOp.
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+ OpOperand &operand, OpBuilder &builder,
+ ShardOp &newShardOp);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index a2424d43a8ba9..83399d10beaae 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -19,18 +19,6 @@ class FuncOp;
namespace mesh {
-/// This enum controls the traversal order for the sharding propagation.
-enum class TraversalOrder {
- /// Forward traversal.
- Forward,
- /// Backward traversal.
- Backward,
- /// Forward then backward traversal.
- ForwardBackward,
- /// Backward then forward traversal.
- BackwardForward
-};
-
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 11ec7e78cd5e6..06ebf151e7d64 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -24,21 +24,6 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
operation, and the operations themselves are added with sharding option
attributes.
}];
- let options = [
- Option<"traversal", "traversal",
- "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
- "Traversal order to use for sharding propagation:",
- [{::llvm::cl::values(
- clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
- "Forward only traversal."),
- clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
- "backward only traversal."),
- clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
- "forward-backward traversal."),
- clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
- "backward-forward traversal.")
- )}]>,
- ];
let dependentDialects = [
"mesh::MeshDialect"
];
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b8cc91da722f0..0a01aaf776e7d 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -298,12 +298,13 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
return type;
}
-static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
- Value &operandValue,
- Operation *operandOp,
- OpBuilder &builder,
- ShardOp &newShardOp) {
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+ OpOperand &operand,
+ OpBuilder &builder,
+ ShardOp &newShardOp) {
OpBuilder::InsertionGuard insertionGuard(builder);
+ Value operandValue = operand.get();
+ Operation *operandOp = operand.getOwner();
builder.setInsertionPointAfterValue(operandValue);
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
if (shardOp && sharding == shardOp.getSharding() &&
@@ -322,8 +323,9 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
/*annotate_for_users*/ false);
}
- operandValue.replaceUsesWithIf(
- newShardOp, [operandOp, operandValue](OpOperand &use) {
+ IRRewriter rewriter(builder);
+ rewriter.replaceUsesWithIf(
+ operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
return use.getOwner() == operandOp && use.get() == operandValue;
});
@@ -334,20 +336,15 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
newShardOp.getSharding(),
/*annotate_for_users*/ true);
- newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
+ rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
}
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpResult result,
OpBuilder &builder) {
ShardOp newShardOp;
- SmallVector<std::pair<Value, Operation *>> uses;
- for (auto &use : result.getUses()) {
- uses.emplace_back(use.get(), use.getOwner());
- }
- for (auto &[operandValue, operandOp] : uses) {
- maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
- builder, newShardOp);
+ for (auto &use : llvm::make_early_inc_range(result.getUses())) {
+ maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
}
}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 6751fafaf1776..4452dd65fce9d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -362,9 +362,6 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
//===----------------------------------------------------------------------===//
struct ShardingPropagation
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
-
- using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
-
void runOnOperation() override {
FunctionOpInterface funcOp = getOperation();
MLIRContext *ctx = funcOp.getContext();
@@ -385,31 +382,18 @@ struct ShardingPropagation
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
});
- auto traverse = [&](auto &&range, OpBuilder &builder,
- const char *order) -> bool {
- for (Operation &op : range) {
- if (failed(visitOp(&op, builder))) {
- signalPassFailure();
- return true;
- }
- }
- LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
- << funcOp << "\n");
- LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
- return false;
- };
-
- // 1. Propagate in reversed order.
- if (traversal == TraversalOrder::Backward ||
- traversal == TraversalOrder::BackwardForward)
- traverse(llvm::reverse(block), builder, "backward");
-
- // 2. Propagate in original order.
- if (traversal != TraversalOrder::Backward)
- traverse(block, builder, "forward");
-
- // 3. Propagate in backward order if needed.
- if (traversal == TraversalOrder::ForwardBackward)
- traverse(llvm::reverse(block), builder, "backward");
+ // 1. propagate in reversed order
+ for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+ if (failed(visitOp(&op, builder)))
+ return signalPassFailure();
+
+ LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
+ << funcOp << "\n");
+ LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+
+ // 2. propagate in original order
+ for (Operation &op : llvm::make_early_inc_range(block))
+ if (failed(visitOp(&op, builder)))
+ return signalPassFailure();
}
};
diff --git a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
deleted file mode 100644
index 4223d01d65111..0000000000000
--- a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
+++ /dev/null
@@ -1,26 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-module {
- mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
- func.func @test_forward() -> tensor<6x6xi32> {
- %c1_i32 = arith.constant 1 : i32
- // CHECK: tensor.empty()
- %0 = tensor.empty() : tensor<6x6xi32>
- %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
- // CHECK-COUNT-2: mesh.shard
- %sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32>
- %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32>
- // CHECK: tensor.empty()
- // CHECK-NOT: mesh.shard @
- %2 = tensor.empty() : tensor<6x6xi32>
- %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
- : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
- ^bb0(%in: i32, %in_2: i32, %out: i32):
- %9 = arith.addi %in, %in_2 : i32
- linalg.yield %9 : i32
- } -> tensor<6x6xi32>
- // CHECK: return
- return %3 : tensor<6x6xi32>
- }
-}
diff --git a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
deleted file mode 100644
index dd2eee2f7def8..0000000000000
--- a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
+++ /dev/null
@@ -1,27 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward-backward}))" %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-module {
- mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
- func.func @test_forward() -> tensor<6x6xi32> {
- %c1_i32 = arith.constant 1 : i32
- // CHECK: tensor.empty()
- %0 = tensor.empty() : tensor<6x6xi32>
- // CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]]
- %sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
- %annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32>
- %1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32>
- %2 = tensor.empty() : tensor<6x6xi32>
- // CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]]
- %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1
- : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
- ^bb0(%in: i32, %in_2: i32, %out: i32):
- %9 = arith.addi %in, %in_2 : i32
- linalg.yield %9 : i32
- } -> tensor<6x6xi32>
- %sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding
- %annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32>
- // CHECK: return
- return %annotated_col : tensor<6x6xi32>
- }
-}
diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
deleted file mode 100644
index 98e9931b8de94..0000000000000
--- a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
+++ /dev/null
@@ -1,49 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
- mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
- func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
- %c1_i32 = arith.constant 1 : i32
- // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
- %0 = tensor.empty() : tensor<6x6xi32>
- // CHECK: [[v1:%.*]] = linalg.fill ins
- // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
- %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
- %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
- %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
- // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
- // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
- %3 = tensor.empty() : tensor<6x6xi32>
- // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
- // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
- // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
- // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
- %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
- : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
- ^bb0(%in: i32, %in_2: i32, %out: i32):
- %9 = arith.addi %in, %in_2 : i32
- linalg.yield %9 : i32
- } -> tensor<6x6xi32>
- %c0_i32 = arith.constant 0 : i32
- %6 = tensor.empty() : tensor<i32>
- %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
- // CHECK: [[vreduced:%.*]] = linalg.reduce ins
- // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial = sum [0] : !mesh.sharding
- // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
- %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1]
- (%in: i32, %init: i32) {
- %9 = arith.addi %in, %init : i32
- linalg.yield %9 : i32
- }
- // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
- %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
- %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
- return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
- }
-}
|
DrSergei
pushed a commit
to DrSergei/llvm-project
that referenced
this pull request
Jun 24, 2025
…opagation" (llvm#145531) Reverts llvm#144079 Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140
fschlimb
added a commit
that referenced
this pull request
Jun 26, 2025
anthonyhatran
pushed a commit
to anthonyhatran/llvm-project
that referenced
this pull request
Jun 26, 2025
…opagation" (llvm#145531) Reverts llvm#144079 Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140
anthonyhatran
pushed a commit
to anthonyhatran/llvm-project
that referenced
this pull request
Jun 26, 2025
llvm#144079 introduced a test with an uninitialized access Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140 and got reverted llvm#145531 This PR is an exact copy of llvm#144079 plus a trivial fix (96c8525).
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Reverts #144079
Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140