[mlir][mesh] resubmitting #144079 #145897

Merged
merged 5 commits into llvm:main on Jun 26, 2025

Conversation

fschlimb (Contributor) commented on Jun 26, 2025

#144079 introduced a test with an uninitialized access and was reverted in #145531. Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140

This PR is an exact copy of #144079 plus a trivial fix (96c8525).
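
For context on the fix itself: the crash was an uninitialized read of MeshSharding::partial_type, and 96c8525 simply gives the field an in-class default (see the MeshOps.h hunk below). A minimal sketch of the pattern, using simplified stand-in structs for illustration:

// Before: partial_type starts out indeterminate; reading it before any
// assignment is undefined behavior, which is what the buildbot flagged.
enum class ReductionKind { Sum, Max, Min };

struct MeshShardingBefore {
  ReductionKind partial_type; // never written on some construction paths
};

// After (the shape of the 96c8525 fix): an in-class initializer makes
// every construction path start from a well-defined value.
struct MeshShardingAfter {
  ReductionKind partial_type = ReductionKind::Sum;
};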

llvmbot added the mlir label on Jun 26, 2025
llvmbot (Member) commented on Jun 26, 2025

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes

#144079 introduced a test with an uninitialized access and was reverted in b0ef912. Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140

This PR is an exact copy of #144079 plus a trivial fix (96c8525).


Full diff: https://github.com/llvm/llvm-project/pull/145897.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+1-4)
  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h (+12)
  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (+15)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+15-12)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+29-13)
  • (added) mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir (+26)
  • (added) mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir (+27)
  • (added) mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir (+49)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 1dc178586e918..7213fde45c695 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -44,7 +44,7 @@ class MeshSharding {
   ::mlir::FlatSymbolRefAttr mesh;
   SmallVector<MeshAxesAttr> split_axes;
   SmallVector<MeshAxis> partial_axes;
-  ReductionKind partial_type;
+  ReductionKind partial_type = ReductionKind::Sum;
   SmallVector<int64_t> static_halo_sizes;
   SmallVector<int64_t> static_sharded_dims_offsets;
   SmallVector<Value> dynamic_halo_sizes;
@@ -206,9 +206,6 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 // Use newShardOp if it is not null. Otherwise create a new one.
 // May insert resharding if required.
 // Potentially updates newShardOp.
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                         OpOperand &operand, OpBuilder &builder,
-                                         ShardOp &newShardOp);
 void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
                                          OpBuilder &builder);
 void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index 83399d10beaae..a2424d43a8ba9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -19,6 +19,18 @@ class FuncOp;
 
 namespace mesh {
 
+/// This enum controls the traversal order for the sharding propagation.
+enum class TraversalOrder {
+  /// Forward traversal.
+  Forward,
+  /// Backward traversal.
+  Backward,
+  /// Forward then backward traversal.
+  ForwardBackward,
+  /// Backward then forward traversal.
+  BackwardForward
+};
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 06ebf151e7d64..11ec7e78cd5e6 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -24,6 +24,21 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
     operation, and the operations themselves are added with sharding option
     attributes.
   }];
+  let options = [
+    Option<"traversal", "traversal",
+           "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
+           "Traversal order to use for sharding propagation:",
+            [{::llvm::cl::values(
+              clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
+              "Forward only traversal."),
+              clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
+              "backward only traversal."),
+              clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
+              "forward-backward traversal."),
+              clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
+              "backward-forward traversal.")
+            )}]>,
+  ];
   let dependentDialects = [
     "mesh::MeshDialect"
   ];
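
For reference, the tests added below drive this new option through mlir-opt's pass-pipeline syntax; three of the four spellings registered above are exercised by the RUN lines:

mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" input.mlir
mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" input.mlir
mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward-backward}))" input.mlir

Omitting the option keeps the default, backward-forward, which matches the pass's previous hard-coded behavior (a reverse pass followed by a forward pass).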
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 0a01aaf776e7d..b8cc91da722f0 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -298,13 +298,12 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   return type;
 }
 
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                                     OpOperand &operand,
-                                                     OpBuilder &builder,
-                                                     ShardOp &newShardOp) {
+static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
+                                                    Value &operandValue,
+                                                    Operation *operandOp,
+                                                    OpBuilder &builder,
+                                                    ShardOp &newShardOp) {
   OpBuilder::InsertionGuard insertionGuard(builder);
-  Value operandValue = operand.get();
-  Operation *operandOp = operand.getOwner();
   builder.setInsertionPointAfterValue(operandValue);
   ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
   if (shardOp && sharding == shardOp.getSharding() &&
@@ -323,9 +322,8 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
         builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                                 /*annotate_for_users*/ false);
   }
-  IRRewriter rewriter(builder);
-  rewriter.replaceUsesWithIf(
-      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
+  operandValue.replaceUsesWithIf(
+      newShardOp, [operandOp, operandValue](OpOperand &use) {
         return use.getOwner() == operandOp && use.get() == operandValue;
       });
 
@@ -336,15 +334,20 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
   auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
                                              newShardOp.getSharding(),
                                              /*annotate_for_users*/ true);
-  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
 }
 
 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpResult result,
                                                      OpBuilder &builder) {
   ShardOp newShardOp;
-  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
-    maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
+  SmallVector<std::pair<Value, Operation *>> uses;
+  for (auto &use : result.getUses()) {
+    uses.emplace_back(use.get(), use.getOwner());
+  }
+  for (auto &[operandValue, operandOp] : uses) {
+    maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
+                                            builder, newShardOp);
   }
 }
 
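One subtlety in the hunk above: the rewritten maybeInsertTargetShardingAnnotation snapshots each use as a (value, owner) pair before creating any shard ops, because inserting a ShardOp and calling replaceUsesWithIf mutates the result's use list, which the previous make_early_inc_range iteration did not fully protect against. A standalone sketch of the snapshot pattern, with the hypothetical rewriteUse standing in for the real annotation logic:

#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical helper; the real code inserts a mesh.shard between
// `value` and `owner` and redirects that one use.
static void rewriteUse(mlir::Value value, mlir::Operation *owner);

static void rewriteAllUses(mlir::OpResult result) {
  // Copy the uses out first: the rewriting below edits the use list
  // itself, which would invalidate any live use iterator.
  llvm::SmallVector<std::pair<mlir::Value, mlir::Operation *>> uses;
  for (mlir::OpOperand &use : result.getUses())
    uses.emplace_back(use.get(), use.getOwner());
  for (auto &[value, owner] : uses)
    rewriteUse(value, owner);
}
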
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 4452dd65fce9d..6751fafaf1776 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -362,6 +362,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 //===----------------------------------------------------------------------===//
 struct ShardingPropagation
     : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
+
+  using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
+
   void runOnOperation() override {
     FunctionOpInterface funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
@@ -382,18 +385,31 @@ struct ShardingPropagation
             shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
         });
 
-    // 1. propagate in reversed order
-    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
-      if (failed(visitOp(&op, builder)))
-        return signalPassFailure();
-
-    LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
-                      << funcOp << "\n");
-    LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
-
-    // 2. propagate in original order
-    for (Operation &op : llvm::make_early_inc_range(block))
-      if (failed(visitOp(&op, builder)))
-        return signalPassFailure();
+    auto traverse = [&](auto &&range, OpBuilder &builder,
+                        const char *order) -> bool {
+      for (Operation &op : range) {
+        if (failed(visitOp(&op, builder))) {
+          signalPassFailure();
+          return true;
+        }
+      }
+      LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
+                        << funcOp << "\n");
+      LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+      return false;
+    };
+
+    // 1. Propagate in reversed order.
+    if (traversal == TraversalOrder::Backward ||
+        traversal == TraversalOrder::BackwardForward)
+      traverse(llvm::reverse(block), builder, "backward");
+
+    // 2. Propagate in original order.
+    if (traversal != TraversalOrder::Backward)
+      traverse(block, builder, "forward");
+
+    // 3. Propagate in backward order if needed.
+    if (traversal == TraversalOrder::ForwardBackward)
+      traverse(llvm::reverse(block), builder, "backward");
   }
 };
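
A note on the refactor above: traverse returns true when visitOp fails, but the three call sites ignore that result, so a later phase can still run after signalPassFailure(); the pass is already marked failed either way, so this only costs wasted work. A guarded variant would look like this (a sketch of an alternative, not what the PR does):

// Hypothetical early-exit wiring around the same traverse lambda.
if (traversal == TraversalOrder::Backward ||
    traversal == TraversalOrder::BackwardForward)
  if (traverse(llvm::reverse(block), builder, "backward"))
    return;
if (traversal != TraversalOrder::Backward)
  if (traverse(block, builder, "forward"))
    return;
if (traversal == TraversalOrder::ForwardBackward)
  (void)traverse(llvm::reverse(block), builder, "backward");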
diff --git a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
new file mode 100644
index 0000000000000..4223d01d65111
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  func.func @test_forward() -> tensor<6x6xi32> {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: tensor.empty()
+    %0 = tensor.empty() : tensor<6x6xi32>
+    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+    // CHECK-COUNT-2: mesh.shard
+    %sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32>
+    // CHECK: tensor.empty()
+    // CHECK-NOT: mesh.shard @
+    %2 = tensor.empty() : tensor<6x6xi32>
+    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
+        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
+    ^bb0(%in: i32, %in_2: i32, %out: i32):
+      %9 = arith.addi %in, %in_2 : i32
+      linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    // CHECK: return
+    return %3 : tensor<6x6xi32>
+  }
+}
diff --git a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
new file mode 100644
index 0000000000000..dd2eee2f7def8
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward-backward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  func.func @test_forward() -> tensor<6x6xi32> {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: tensor.empty()
+    %0 = tensor.empty() : tensor<6x6xi32>
+    // CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]]
+    %sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+    %annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32>
+    %2 = tensor.empty() : tensor<6x6xi32>
+    // CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]]
+    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1
+        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
+    ^bb0(%in: i32, %in_2: i32, %out: i32):
+      %9 = arith.addi %in, %in_2 : i32
+      linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    %sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding
+    %annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32>
+    // CHECK: return
+    return %annotated_col : tensor<6x6xi32>
+  }
+}
diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
new file mode 100644
index 0000000000000..98e9931b8de94
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
+  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
+    %0 = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[v1:%.*]] = linalg.fill ins
+    // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
+    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+    %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
+    // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
+    %3 = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
+    // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+    // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
+    // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
+    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
+        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
+    ^bb0(%in: i32, %in_2: i32, %out: i32):
+      %9 = arith.addi %in, %in_2 : i32
+      linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    %c0_i32 = arith.constant 0 : i32
+    %6 = tensor.empty() : tensor<i32>
+    %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
+    // CHECK: [[vreduced:%.*]] = linalg.reduce ins
+    // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial =  sum [0] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
+    %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1] 
+      (%in: i32, %init: i32) {
+        %9 = arith.addi %in, %init : i32
+        linalg.yield %9 : i32
+      }
+    // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
+    %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
+    %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
+    return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
+  }
+}

fschlimb merged commit 2bece21 into llvm:main on Jun 26, 2025
9 checks passed
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request on Jun 26, 2025:

llvm#144079 introduced a test with an uninitialized access and was reverted in llvm#145531. Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140

This PR is an exact copy of llvm#144079 plus a trivial fix (96c8525).