
[mlir][mesh] resubmitting #144079 #145897


Merged 5 commits on Jun 26, 2025.
Changes from all commits:
5 changes: 1 addition & 4 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -44,7 +44,7 @@ class MeshSharding {
::mlir::FlatSymbolRefAttr mesh;
SmallVector<MeshAxesAttr> split_axes;
SmallVector<MeshAxis> partial_axes;
-  ReductionKind partial_type;
+  ReductionKind partial_type = ReductionKind::Sum;
SmallVector<int64_t> static_halo_sizes;
SmallVector<int64_t> static_sharded_dims_offsets;
SmallVector<Value> dynamic_halo_sizes;
@@ -206,9 +206,6 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
// Use newShardOp if it is not null. Otherwise create a new one.
// May insert resharding if required.
// Potentially updates newShardOp.
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                         OpOperand &operand, OpBuilder &builder,
-                                         ShardOp &newShardOp);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
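Side note on the partial_type change above: the in-class initializer ensures a default-constructed MeshSharding carries a well-defined ReductionKind (Sum) rather than an indeterminate value. A standalone illustration of the underlying C++ rule (hypothetical structs, not from the patch):

struct Unsafe { int kind; };    // "Unsafe u;" leaves u.kind indeterminate
struct Safe { int kind = 0; };  // "Safe s;" guarantees s.kind == 0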
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -19,6 +19,18 @@ class FuncOp;

namespace mesh {

/// This enum controls the traversal order for the sharding propagation.
enum class TraversalOrder {
/// Forward traversal.
Forward,
/// Backward traversal.
Backward,
/// Forward then backward traversal.
ForwardBackward,
/// Backward then forward traversal.
BackwardForward
};

//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
15 changes: 15 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -24,6 +24,21 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
operation, and the operations themselves are added with sharding option
attributes.
}];
let options = [
Option<"traversal", "traversal",
"mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
"Traversal order to use for sharding propagation:",
[{::llvm::cl::values(
clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
"Forward only traversal."),
clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
"backward only traversal."),
clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
"forward-backward traversal."),
clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
"backward-forward traversal.")
)}]>,
];
let dependentDialects = [
"mesh::MeshDialect"
];
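As the new tests below exercise, the option is selected from the pass-pipeline string; a typical invocation (input.mlir being a placeholder file) would be:

mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" input.mlir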
27 changes: 15 additions & 12 deletions mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -298,13 +298,12 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
return type;
}

-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                                     OpOperand &operand,
-                                                     OpBuilder &builder,
-                                                     ShardOp &newShardOp) {
+static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
+                                                    Value &operandValue,
+                                                    Operation *operandOp,
+                                                    OpBuilder &builder,
+                                                    ShardOp &newShardOp) {
OpBuilder::InsertionGuard insertionGuard(builder);
-  Value operandValue = operand.get();
-  Operation *operandOp = operand.getOwner();
builder.setInsertionPointAfterValue(operandValue);
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
if (shardOp && sharding == shardOp.getSharding() &&
@@ -323,9 +322,8 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
/*annotate_for_users*/ false);
}
-  IRRewriter rewriter(builder);
-  rewriter.replaceUsesWithIf(
-      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
+  operandValue.replaceUsesWithIf(
+      newShardOp, [operandOp, operandValue](OpOperand &use) {
return use.getOwner() == operandOp && use.get() == operandValue;
});

@@ -336,15 +334,20 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
newShardOp.getSharding(),
/*annotate_for_users*/ true);
-  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
}

void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpResult result,
OpBuilder &builder) {
ShardOp newShardOp;
-  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
-    maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
+  SmallVector<std::pair<Value, Operation *>> uses;
+  for (auto &use : result.getUses()) {
+    uses.emplace_back(use.get(), use.getOwner());
+  }
+  for (auto &[operandValue, operandOp] : uses) {
+    maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
+                                            builder, newShardOp);
}
}

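A note on the restructuring above: the uses of the result are snapshotted as (Value, Operation*) pairs before any IR is mutated, since creating ShardOps and rewriting uses while iterating result.getUses() would invalidate the use-list iterator. The general shape of the pattern, sketched with a hypothetical rewriteUse helper:

SmallVector<std::pair<Value, Operation *>> uses;
for (OpOperand &use : result.getUses())
  uses.emplace_back(use.get(), use.getOwner()); // read-only pass, no mutation
for (auto &[value, owner] : uses)
  rewriteUse(value, owner); // mutation happens outside the use-list walk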
42 changes: 29 additions & 13 deletions mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -362,6 +362,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
//===----------------------------------------------------------------------===//
struct ShardingPropagation
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> {

+  using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;

void runOnOperation() override {
FunctionOpInterface funcOp = getOperation();
MLIRContext *ctx = funcOp.getContext();
@@ -382,18 +385,31 @@ struct ShardingPropagation
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
});

-    // 1. propagate in reversed order
-    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
-      if (failed(visitOp(&op, builder)))
-        return signalPassFailure();
-
-    LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
-                      << funcOp << "\n");
-    LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
-
-    // 2. propagate in original order
-    for (Operation &op : llvm::make_early_inc_range(block))
-      if (failed(visitOp(&op, builder)))
-        return signalPassFailure();
+    auto traverse = [&](auto &&range, OpBuilder &builder,
+                        const char *order) -> bool {
+      for (Operation &op : range) {
+        if (failed(visitOp(&op, builder))) {
+          signalPassFailure();
+          return true;
+        }
+      }
+      LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
+                        << funcOp << "\n");
+      LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+      return false;
+    };
+
+    // 1. Propagate in reversed order.
+    if (traversal == TraversalOrder::Backward ||
+        traversal == TraversalOrder::BackwardForward)
+      traverse(llvm::reverse(block), builder, "backward");
+
+    // 2. Propagate in original order.
+    if (traversal != TraversalOrder::Backward)
+      traverse(block, builder, "forward");
+
+    // 3. Propagate in backward order if needed.
+    if (traversal == TraversalOrder::ForwardBackward)
+      traverse(llvm::reverse(block), builder, "backward");
}
};
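Net effect of the three guarded loops above (a summary of the dispatch, not code from the patch):

// traversal=forward          : forward walk only
// traversal=backward         : backward walk only
// traversal=forward-backward : forward walk, then backward walk
// traversal=backward-forward : backward walk, then forward walk (the default)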
26 changes: 26 additions & 0 deletions mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
@@ -0,0 +1,26 @@
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" %s | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
func.func @test_forward() -> tensor<6x6xi32> {
%c1_i32 = arith.constant 1 : i32
// CHECK: tensor.empty()
%0 = tensor.empty() : tensor<6x6xi32>
%sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
// CHECK-COUNT-2: mesh.shard
%sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32>
%1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32>
// CHECK: tensor.empty()
// CHECK-NOT: mesh.shard @
%2 = tensor.empty() : tensor<6x6xi32>
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
: tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
^bb0(%in: i32, %in_2: i32, %out: i32):
%9 = arith.addi %in, %in_2 : i32
linalg.yield %9 : i32
} -> tensor<6x6xi32>
// CHECK: return
return %3 : tensor<6x6xi32>
}
}
27 changes: 27 additions & 0 deletions mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
@@ -0,0 +1,27 @@
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward-backward}))" %s | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
func.func @test_forward() -> tensor<6x6xi32> {
%c1_i32 = arith.constant 1 : i32
// CHECK: tensor.empty()
%0 = tensor.empty() : tensor<6x6xi32>
// CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]]
%sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
%annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32>
%1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32>
%2 = tensor.empty() : tensor<6x6xi32>
// CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]]
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1
: tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
^bb0(%in: i32, %in_2: i32, %out: i32):
%9 = arith.addi %in, %in_2 : i32
linalg.yield %9 : i32
} -> tensor<6x6xi32>
%sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding
%annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32>
// CHECK: return
return %annotated_col : tensor<6x6xi32>
}
}
49 changes: 49 additions & 0 deletions mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
@@ -0,0 +1,49 @@
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
%c1_i32 = arith.constant 1 : i32
// CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
%0 = tensor.empty() : tensor<6x6xi32>
// CHECK: [[v1:%.*]] = linalg.fill ins
// CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
%1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
%sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
%sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
// CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
// CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
%3 = tensor.empty() : tensor<6x6xi32>
// CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
// CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
// CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
%4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
: tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
^bb0(%in: i32, %in_2: i32, %out: i32):
%9 = arith.addi %in, %in_2 : i32
linalg.yield %9 : i32
} -> tensor<6x6xi32>
%c0_i32 = arith.constant 0 : i32
%6 = tensor.empty() : tensor<i32>
%7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
// CHECK: [[vreduced:%.*]] = linalg.reduce ins
// CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial = sum [0] : !mesh.sharding
// CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
%reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1]
(%in: i32, %init: i32) {
%9 = arith.addi %in, %init : i32
linalg.yield %9 : i32
}
// CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
%sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
// CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
%sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
}
}