
[mlir][mesh] adding option for traversal order in sharding propagation #144079


Merged: 4 commits merged into llvm:main on Jun 18, 2025

Conversation

fschlimb (Contributor)

The traversal order in sharding propagation was hard-coded. This PR adds an option to the pass for selecting a suitable order:

  • forward-only
  • backward-only
  • forward-backward
  • backward-forward

The default is the previous behavior (backward-forward).
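
For example, the traversal order can be selected with the standard pass-option syntax, mirroring the RUN line of the added test (input.mlir is a placeholder file name):

  mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward-backward}))" input.mlir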

FYI @tkarna

llvmbot added the mlir:core (MLIR Core Infrastructure) and mlir labels on Jun 13, 2025
fschlimb (Contributor, Author)

FYI @yaochengji

llvmbot (Member) commented Jun 13, 2025

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes

The traversal order in sharding propagation was hard-coded. This PR adds an option to the pass for selecting a suitable order:

  • forward-only
  • backward-only
  • forward-backward
  • backward-forward

The default is the previous behavior (backward-forward).

FYI @tkarna


Full diff: https://github.com/llvm/llvm-project/pull/144079.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h (+12)
  • (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (+15)
  • (modified) mlir/include/mlir/Transforms/Passes.h (+1)
  • (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+30-10)
  • (added) mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir (+49)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index 83399d10beaae..a2424d43a8ba9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -19,6 +19,18 @@ class FuncOp;
 
 namespace mesh {
 
+/// This enum controls the traversal order for the sharding propagation.
+enum class TraversalOrder {
+  /// Forward traversal.
+  Forward,
+  /// Backward traversal.
+  Backward,
+  /// Forward then backward traversal.
+  ForwardBackward,
+  /// Backward then forward traversal.
+  BackwardForward
+};
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 06ebf151e7d64..11ec7e78cd5e6 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -24,6 +24,21 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
     operation, and the operations themselves are added with sharding option
     attributes.
   }];
+  let options = [
+    Option<"traversal", "traversal",
+           "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
+           "Traversal order to use for sharding propagation:",
+            [{::llvm::cl::values(
+              clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
+              "Forward only traversal."),
+              clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
+              "backward only traversal."),
+              clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
+              "forward-backward traversal."),
+              clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
+              "backward-forward traversal.")
+            )}]>,
+  ];
   let dependentDialects = [
     "mesh::MeshDialect"
   ];
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 41f208216374f..16cdbebf91900 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -18,6 +18,7 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LocationSnapshot.h"
+// #include "mlir/Transforms/ShardingPropagationUtils.h"
 #include "mlir/Transforms/ViewOpGraph.h"
 #include "llvm/Support/Debug.h"
 #include <limits>
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 4452dd65fce9d..9d4a144912ee2 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -362,6 +362,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 //===----------------------------------------------------------------------===//
 struct ShardingPropagation
     : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
+
+  using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
+
   void runOnOperation() override {
     FunctionOpInterface funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
@@ -383,17 +386,34 @@ struct ShardingPropagation
         });
 
     // 1. propagate in reversed order
-    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
-      if (failed(visitOp(&op, builder)))
-        return signalPassFailure();
-
-    LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
-                      << funcOp << "\n");
-    LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+    if (traversal == TraversalOrder::Backward ||
+        traversal == TraversalOrder::BackwardForward) {
+      for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+        if (failed(visitOp(&op, builder)))
+          return signalPassFailure();
+      if (traversal == TraversalOrder::BackwardForward) {
+        LLVM_DEBUG(DBGS() << "After backward order propagation:\n"
+                          << funcOp << "\n");
+        LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+      }
+    }
 
     // 2. propagate in original order
-    for (Operation &op : llvm::make_early_inc_range(block))
-      if (failed(visitOp(&op, builder)))
-        return signalPassFailure();
+    if (traversal != TraversalOrder::Backward) {
+      for (Operation &op : llvm::make_early_inc_range(block))
+        if (failed(visitOp(&op, builder)))
+          return signalPassFailure();
+      if (traversal == TraversalOrder::ForwardBackward) {
+        LLVM_DEBUG(DBGS() << "After forward order propagation:\n"
+                          << funcOp << "\n");
+        LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+      }
+    }
+
+    // 3. propagate in backward order if needed
+    if (traversal == TraversalOrder::ForwardBackward)
+      for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+        if (failed(visitOp(&op, builder)))
+          return signalPassFailure();
   }
 };
diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
new file mode 100644
index 0000000000000..98e9931b8de94
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
+  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
+    %0 = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[v1:%.*]] = linalg.fill ins
+    // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
+    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+    %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
+    // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
+    %3 = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
+    // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+    // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
+    // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
+    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
+        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
+    ^bb0(%in: i32, %in_2: i32, %out: i32):
+      %9 = arith.addi %in, %in_2 : i32
+      linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    %c0_i32 = arith.constant 0 : i32
+    %6 = tensor.empty() : tensor<i32>
+    %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
+    // CHECK: [[vreduced:%.*]] = linalg.reduce ins
+    // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial =  sum [0] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
+    %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1] 
+      (%in: i32, %init: i32) {
+        %9 = arith.addi %in, %init : i32
+        linalg.yield %9 : i32
+      }
+    // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
+    %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
+    %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
+    return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
+  }
+}

llvmbot (Member) commented Jun 13, 2025

@llvm/pr-subscribers-mlir-core

Author: Frank Schlimbach (fschlimb)

(Same changes summary and full diff as in the comment above.)

rolfmorel (Contributor) left a comment

Except for a nit and a question about test coverage, LGTM!

@@ -383,17 +386,34 @@ struct ShardingPropagation
});

// 1. propagate in reversed order

Nit: make into a sentence, i.e. capitalize and add a full stop.

tkarna (Contributor) left a comment
Looks good to me, minor comments.

Comment on lines 391 to 393
    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
      if (failed(visitOp(&op, builder)))
        return signalPassFailure();

Nitpick: this loop nest is replicated three times. It could be refactored by creating a queue and, depending on the traversal order, pushing block or llvm::reverse(block) onto it, and then iterating over the queue.

fschlimb (Contributor, Author)

The ranges have different template types, so this would be trickier than it seems.
However, I can deduplicate the loop bodies.

fschlimb (Contributor, Author)

Deduplicated code.
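
A minimal sketch of what such a deduplication could look like (hypothetical; the committed code may differ). A generic lambda accepts both the forward and the reversed range even though their types differ:

  // Hypothetical sketch: the generic lambda is instantiated once per range
  // type, so block and llvm::reverse(block) can share the same loop body.
  auto propagate = [&](auto &&range) -> LogicalResult {
    for (Operation &op : llvm::make_early_inc_range(range))
      if (failed(visitOp(&op, builder)))
        return failure();
    return success();
  };
  if (traversal == TraversalOrder::Backward ||
      traversal == TraversalOrder::BackwardForward)
    if (failed(propagate(llvm::reverse(block))))
      return signalPassFailure();
  if (traversal != TraversalOrder::Backward)
    if (failed(propagate(block)))
      return signalPassFailure();
  if (traversal == TraversalOrder::ForwardBackward)
    if (failed(propagate(llvm::reverse(block))))
      return signalPassFailure();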

fschlimb merged commit 43e1a5a into llvm:main on Jun 18, 2025. 7 checks passed.
fschlimb added a commit to fschlimb/llvm-project that referenced this pull request Jun 18, 2025
[mlir][mesh] adding option for traversal order in sharding propagation (llvm#144079)

The traversal order in sharding propagation was hard-coded. This PR
provides options to the pass to select a suitable order:
- forward-only
- backward-only
- forward-backward
- backward-forward

The default is the previous behavior (backward-forward).
qinkunbao (Member)

Hi, the test added in this PR failed on the buildbot. Can you take a look?

https://lab.llvm.org/buildbot/#/builders/164/builds/11125

fschlimb (Contributor, Author)

Hi, the test added in this PR failed on the buildbot. Can you take a look?

https://lab.llvm.org/buildbot/#/builders/164/builds/11125

Sorry for the delay; I didn't see the notification about the issue or your message. Thanks @qinkunbao for handling this in the meantime. I hope this doesn't happen again.
Fixed in #145897.

fschlimb added a commit that referenced this pull request Jun 26, 2025
#144079 introduced a test with an uninitialized access.
Buildbot failure:
https://lab.llvm.org/buildbot/#/builders/164/builds/11140
It got reverted in #145531.

This PR is an exact copy of #144079 plus a trivial fix
(96c8525).
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
llvm#144079 introduced a test with an uninitialized access.
Buildbot failure:
https://lab.llvm.org/buildbot/#/builders/164/builds/11140
It got reverted in llvm#145531.

This PR is an exact copy of llvm#144079 plus a trivial fix
(96c8525).