Skip to content

Revert "[mlir][mesh] adding option for traversal order in sharding propagation" #145531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
// Use newShardOp if it is not null. Otherwise create a new one.
// May insert resharding if required.
// Potentially updates newShardOp.
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand, OpBuilder &builder,
ShardOp &newShardOp);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
Expand Down
12 changes: 0 additions & 12 deletions mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,6 @@ class FuncOp;

namespace mesh {

/// This enum controls the traversal order for the sharding propagation.
enum class TraversalOrder {
/// Forward traversal.
Forward,
/// Backward traversal.
Backward,
/// Forward then backward traversal.
ForwardBackward,
/// Backward then forward traversal.
BackwardForward
};

//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
Expand Down
15 changes: 0 additions & 15 deletions mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,6 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
operation, and the operations themselves are added with sharding option
attributes.
}];
let options = [
Option<"traversal", "traversal",
"mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
"Traversal order to use for sharding propagation:",
[{::llvm::cl::values(
clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
"Forward only traversal."),
clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
"backward only traversal."),
clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
"forward-backward traversal."),
clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
"backward-forward traversal.")
)}]>,
];
let dependentDialects = [
"mesh::MeshDialect"
];
Expand Down
27 changes: 12 additions & 15 deletions mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,13 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
return type;
}

static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
Value &operandValue,
Operation *operandOp,
OpBuilder &builder,
ShardOp &newShardOp) {
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpOperand &operand,
OpBuilder &builder,
ShardOp &newShardOp) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
Operation *operandOp = operand.getOwner();
builder.setInsertionPointAfterValue(operandValue);
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
if (shardOp && sharding == shardOp.getSharding() &&
Expand All @@ -322,8 +323,9 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
/*annotate_for_users*/ false);
}
operandValue.replaceUsesWithIf(
newShardOp, [operandOp, operandValue](OpOperand &use) {
IRRewriter rewriter(builder);
rewriter.replaceUsesWithIf(
operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
return use.getOwner() == operandOp && use.get() == operandValue;
});

Expand All @@ -334,20 +336,15 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
newShardOp.getSharding(),
/*annotate_for_users*/ true);
newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
}

void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpResult result,
OpBuilder &builder) {
ShardOp newShardOp;
SmallVector<std::pair<Value, Operation *>> uses;
for (auto &use : result.getUses()) {
uses.emplace_back(use.get(), use.getOwner());
}
for (auto &[operandValue, operandOp] : uses) {
maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
builder, newShardOp);
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
}
}

Expand Down
42 changes: 13 additions & 29 deletions mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,6 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
//===----------------------------------------------------------------------===//
struct ShardingPropagation
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> {

using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;

void runOnOperation() override {
FunctionOpInterface funcOp = getOperation();
MLIRContext *ctx = funcOp.getContext();
Expand All @@ -385,31 +382,18 @@ struct ShardingPropagation
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
});

auto traverse = [&](auto &&range, OpBuilder &builder,
const char *order) -> bool {
for (Operation &op : range) {
if (failed(visitOp(&op, builder))) {
signalPassFailure();
return true;
}
}
LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
<< funcOp << "\n");
LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
return false;
};

// 1. Propagate in reversed order.
if (traversal == TraversalOrder::Backward ||
traversal == TraversalOrder::BackwardForward)
traverse(llvm::reverse(block), builder, "backward");

// 2. Propagate in original order.
if (traversal != TraversalOrder::Backward)
traverse(block, builder, "forward");

// 3. Propagate in backward order if needed.
if (traversal == TraversalOrder::ForwardBackward)
traverse(llvm::reverse(block), builder, "backward");
// 1. propagate in reversed order
for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
if (failed(visitOp(&op, builder)))
return signalPassFailure();

LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
<< funcOp << "\n");
LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));

// 2. propagate in original order
for (Operation &op : llvm::make_early_inc_range(block))
if (failed(visitOp(&op, builder)))
return signalPassFailure();
}
};
26 changes: 0 additions & 26 deletions mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir

This file was deleted.

27 changes: 0 additions & 27 deletions mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir

This file was deleted.

49 changes: 0 additions & 49 deletions mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir

This file was deleted.

Loading