Skip to content

Revert "[mlir][mesh] adding option for traversal order in sharding propagation" #145531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 24, 2025

Conversation

qinkunbao
Copy link
Member

@llvmbot llvmbot added the mlir label Jun 24, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 24, 2025

@llvm/pr-subscribers-mlir

Author: Qinkun Bao (qinkunbao)

Changes

Reverts llvm/llvm-project#144079

Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140


Full diff: https://github.com/llvm/llvm-project/pull/145531.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+3)
  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h (-12)
  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (-15)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+12-15)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+13-29)
  • (removed) mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir (-26)
  • (removed) mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir (-27)
  • (removed) mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir (-49)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index c4d512b60bc51..1dc178586e918 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -206,6 +206,9 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 // Use newShardOp if it is not null. Otherwise create a new one.
 // May insert resharding if required.
 // Potentially updates newShardOp.
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+                                         OpOperand &operand, OpBuilder &builder,
+                                         ShardOp &newShardOp);
 void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
                                          OpBuilder &builder);
 void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index a2424d43a8ba9..83399d10beaae 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -19,18 +19,6 @@ class FuncOp;
 
 namespace mesh {
 
-/// This enum controls the traversal order for the sharding propagation.
-enum class TraversalOrder {
-  /// Forward traversal.
-  Forward,
-  /// Backward traversal.
-  Backward,
-  /// Forward then backward traversal.
-  ForwardBackward,
-  /// Backward then forward traversal.
-  BackwardForward
-};
-
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 11ec7e78cd5e6..06ebf151e7d64 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -24,21 +24,6 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
     operation, and the operations themselves are added with sharding option
     attributes.
   }];
-  let options = [
-    Option<"traversal", "traversal",
-           "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
-           "Traversal order to use for sharding propagation:",
-            [{::llvm::cl::values(
-              clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
-              "Forward only traversal."),
-              clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
-              "backward only traversal."),
-              clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
-              "forward-backward traversal."),
-              clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
-              "backward-forward traversal.")
-            )}]>,
-  ];
   let dependentDialects = [
     "mesh::MeshDialect"
   ];
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b8cc91da722f0..0a01aaf776e7d 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -298,12 +298,13 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   return type;
 }
 
-static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
-                                                    Value &operandValue,
-                                                    Operation *operandOp,
-                                                    OpBuilder &builder,
-                                                    ShardOp &newShardOp) {
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+                                                     OpOperand &operand,
+                                                     OpBuilder &builder,
+                                                     ShardOp &newShardOp) {
   OpBuilder::InsertionGuard insertionGuard(builder);
+  Value operandValue = operand.get();
+  Operation *operandOp = operand.getOwner();
   builder.setInsertionPointAfterValue(operandValue);
   ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
   if (shardOp && sharding == shardOp.getSharding() &&
@@ -322,8 +323,9 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
         builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                                 /*annotate_for_users*/ false);
   }
-  operandValue.replaceUsesWithIf(
-      newShardOp, [operandOp, operandValue](OpOperand &use) {
+  IRRewriter rewriter(builder);
+  rewriter.replaceUsesWithIf(
+      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
         return use.getOwner() == operandOp && use.get() == operandValue;
       });
 
@@ -334,20 +336,15 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
   auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
                                              newShardOp.getSharding(),
                                              /*annotate_for_users*/ true);
-  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
+  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
 }
 
 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpResult result,
                                                      OpBuilder &builder) {
   ShardOp newShardOp;
-  SmallVector<std::pair<Value, Operation *>> uses;
-  for (auto &use : result.getUses()) {
-    uses.emplace_back(use.get(), use.getOwner());
-  }
-  for (auto &[operandValue, operandOp] : uses) {
-    maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
-                                            builder, newShardOp);
+  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
+    maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
   }
 }
 
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 6751fafaf1776..4452dd65fce9d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -362,9 +362,6 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 //===----------------------------------------------------------------------===//
 struct ShardingPropagation
     : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
-
-  using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
-
   void runOnOperation() override {
     FunctionOpInterface funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
@@ -385,31 +382,18 @@ struct ShardingPropagation
             shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
         });
 
-    auto traverse = [&](auto &&range, OpBuilder &builder,
-                        const char *order) -> bool {
-      for (Operation &op : range) {
-        if (failed(visitOp(&op, builder))) {
-          signalPassFailure();
-          return true;
-        }
-      }
-      LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
-                        << funcOp << "\n");
-      LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
-      return false;
-    };
-
-    // 1. Propagate in reversed order.
-    if (traversal == TraversalOrder::Backward ||
-        traversal == TraversalOrder::BackwardForward)
-      traverse(llvm::reverse(block), builder, "backward");
-
-    // 2. Propagate in original order.
-    if (traversal != TraversalOrder::Backward)
-      traverse(block, builder, "forward");
-
-    // 3. Propagate in backward order if needed.
-    if (traversal == TraversalOrder::ForwardBackward)
-      traverse(llvm::reverse(block), builder, "backward");
+    // 1. propagate in reversed order
+    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+      if (failed(visitOp(&op, builder)))
+        return signalPassFailure();
+
+    LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
+                      << funcOp << "\n");
+    LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+
+    // 2. propagate in original order
+    for (Operation &op : llvm::make_early_inc_range(block))
+      if (failed(visitOp(&op, builder)))
+        return signalPassFailure();
   }
 };
diff --git a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
deleted file mode 100644
index 4223d01d65111..0000000000000
--- a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
+++ /dev/null
@@ -1,26 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-module {
-  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
-  func.func @test_forward() -> tensor<6x6xi32> {
-    %c1_i32 = arith.constant 1 : i32
-    // CHECK: tensor.empty()
-    %0 = tensor.empty() : tensor<6x6xi32>
-    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
-    // CHECK-COUNT-2: mesh.shard
-    %sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32>
-    %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32>
-    // CHECK: tensor.empty()
-    // CHECK-NOT: mesh.shard @
-    %2 = tensor.empty() : tensor<6x6xi32>
-    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
-        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
-    ^bb0(%in: i32, %in_2: i32, %out: i32):
-      %9 = arith.addi %in, %in_2 : i32
-      linalg.yield %9 : i32
-    } -> tensor<6x6xi32>
-    // CHECK: return
-    return %3 : tensor<6x6xi32>
-  }
-}
diff --git a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
deleted file mode 100644
index dd2eee2f7def8..0000000000000
--- a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
+++ /dev/null
@@ -1,27 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward-backward}))" %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-module {
-  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
-  func.func @test_forward() -> tensor<6x6xi32> {
-    %c1_i32 = arith.constant 1 : i32
-    // CHECK: tensor.empty()
-    %0 = tensor.empty() : tensor<6x6xi32>
-    // CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]]
-    %sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
-    %annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32>
-    %1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32>
-    %2 = tensor.empty() : tensor<6x6xi32>
-    // CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]]
-    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1
-        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
-    ^bb0(%in: i32, %in_2: i32, %out: i32):
-      %9 = arith.addi %in, %in_2 : i32
-      linalg.yield %9 : i32
-    } -> tensor<6x6xi32>
-    %sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding
-    %annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32>
-    // CHECK: return
-    return %annotated_col : tensor<6x6xi32>
-  }
-}
diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
deleted file mode 100644
index 98e9931b8de94..0000000000000
--- a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
+++ /dev/null
@@ -1,49 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
-  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
-  func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
-    %c1_i32 = arith.constant 1 : i32
-    // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
-    %0 = tensor.empty() : tensor<6x6xi32>
-    // CHECK: [[v1:%.*]] = linalg.fill ins
-    // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
-    %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
-    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
-    %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
-    // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
-    // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
-    %3 = tensor.empty() : tensor<6x6xi32>
-    // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
-    // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
-    // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
-    // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
-    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
-        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
-    ^bb0(%in: i32, %in_2: i32, %out: i32):
-      %9 = arith.addi %in, %in_2 : i32
-      linalg.yield %9 : i32
-    } -> tensor<6x6xi32>
-    %c0_i32 = arith.constant 0 : i32
-    %6 = tensor.empty() : tensor<i32>
-    %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
-    // CHECK: [[vreduced:%.*]] = linalg.reduce ins
-    // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial =  sum [0] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
-    %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1] 
-      (%in: i32, %init: i32) {
-        %9 = arith.addi %in, %init : i32
-        linalg.yield %9 : i32
-      }
-    // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
-    %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
-    %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
-    return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
-  }
-}

@qinkunbao qinkunbao merged commit b0ef912 into main Jun 24, 2025
7 of 9 checks passed
@qinkunbao qinkunbao deleted the revert-144079-fwbw branch June 24, 2025 15:26
DrSergei pushed a commit to DrSergei/llvm-project that referenced this pull request Jun 24, 2025
fschlimb added a commit that referenced this pull request Jun 26, 2025
#144079 introduced a test with an uninitialized access
Buildbot failure:
https://lab.llvm.org/buildbot/#/builders/164/builds/11140
and got reverted #145531

This PR is an exact copy of #144079 plus a trivial fix
(96c8525).
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
llvm#144079 introduced a test with an uninitialized access
Buildbot failure:
https://lab.llvm.org/buildbot/#/builders/164/builds/11140
and got reverted llvm#145531

This PR is an exact copy of llvm#144079 plus a trivial fix
(96c8525).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants