Skip to content

Commit b0ef912

Browse files
authored
Revert "[mlir][mesh] adding option for traversal order in sharding propagation" (#145531)
Reverts #144079 Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140
1 parent 43d042b commit b0ef912

File tree

8 files changed

+28
-173
lines changed

8 files changed

+28
-173
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
206206
// Use newShardOp if it is not null. Otherwise create a new one.
207207
// May insert resharding if required.
208208
// Potentially updates newShardOp.
209+
void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
210+
OpOperand &operand, OpBuilder &builder,
211+
ShardOp &newShardOp);
209212
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
210213
OpBuilder &builder);
211214
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,

mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,6 @@ class FuncOp;
1919

2020
namespace mesh {
2121

22-
/// This enum controls the traversal order for the sharding propagation.
23-
enum class TraversalOrder {
24-
/// Forward traversal.
25-
Forward,
26-
/// Backward traversal.
27-
Backward,
28-
/// Forward then backward traversal.
29-
ForwardBackward,
30-
/// Backward then forward traversal.
31-
BackwardForward
32-
};
33-
3422
//===----------------------------------------------------------------------===//
3523
// Passes
3624
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,6 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
2424
operation, and the operations themselves are added with sharding option
2525
attributes.
2626
}];
27-
let options = [
28-
Option<"traversal", "traversal",
29-
"mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
30-
"Traversal order to use for sharding propagation:",
31-
[{::llvm::cl::values(
32-
clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
33-
"Forward only traversal."),
34-
clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
35-
"backward only traversal."),
36-
clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
37-
"forward-backward traversal."),
38-
clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
39-
"backward-forward traversal.")
40-
)}]>,
41-
];
4227
let dependentDialects = [
4328
"mesh::MeshDialect"
4429
];

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -298,12 +298,13 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
298298
return type;
299299
}
300300

301-
static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
302-
Value &operandValue,
303-
Operation *operandOp,
304-
OpBuilder &builder,
305-
ShardOp &newShardOp) {
301+
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
302+
OpOperand &operand,
303+
OpBuilder &builder,
304+
ShardOp &newShardOp) {
306305
OpBuilder::InsertionGuard insertionGuard(builder);
306+
Value operandValue = operand.get();
307+
Operation *operandOp = operand.getOwner();
307308
builder.setInsertionPointAfterValue(operandValue);
308309
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
309310
if (shardOp && sharding == shardOp.getSharding() &&
@@ -322,8 +323,9 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
322323
builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
323324
/*annotate_for_users*/ false);
324325
}
325-
operandValue.replaceUsesWithIf(
326-
newShardOp, [operandOp, operandValue](OpOperand &use) {
326+
IRRewriter rewriter(builder);
327+
rewriter.replaceUsesWithIf(
328+
operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
327329
return use.getOwner() == operandOp && use.get() == operandValue;
328330
});
329331

@@ -334,20 +336,15 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
334336
auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
335337
newShardOp.getSharding(),
336338
/*annotate_for_users*/ true);
337-
newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
339+
rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
338340
}
339341

340342
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
341343
OpResult result,
342344
OpBuilder &builder) {
343345
ShardOp newShardOp;
344-
SmallVector<std::pair<Value, Operation *>> uses;
345-
for (auto &use : result.getUses()) {
346-
uses.emplace_back(use.get(), use.getOwner());
347-
}
348-
for (auto &[operandValue, operandOp] : uses) {
349-
maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
350-
builder, newShardOp);
346+
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
347+
maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
351348
}
352349
}
353350

mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,6 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
362362
//===----------------------------------------------------------------------===//
363363
struct ShardingPropagation
364364
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
365-
366-
using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
367-
368365
void runOnOperation() override {
369366
FunctionOpInterface funcOp = getOperation();
370367
MLIRContext *ctx = funcOp.getContext();
@@ -385,31 +382,18 @@ struct ShardingPropagation
385382
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
386383
});
387384

388-
auto traverse = [&](auto &&range, OpBuilder &builder,
389-
const char *order) -> bool {
390-
for (Operation &op : range) {
391-
if (failed(visitOp(&op, builder))) {
392-
signalPassFailure();
393-
return true;
394-
}
395-
}
396-
LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
397-
<< funcOp << "\n");
398-
LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
399-
return false;
400-
};
401-
402-
// 1. Propagate in reversed order.
403-
if (traversal == TraversalOrder::Backward ||
404-
traversal == TraversalOrder::BackwardForward)
405-
traverse(llvm::reverse(block), builder, "backward");
406-
407-
// 2. Propagate in original order.
408-
if (traversal != TraversalOrder::Backward)
409-
traverse(block, builder, "forward");
410-
411-
// 3. Propagate in backward order if needed.
412-
if (traversal == TraversalOrder::ForwardBackward)
413-
traverse(llvm::reverse(block), builder, "backward");
385+
// 1. propagate in reversed order
386+
for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
387+
if (failed(visitOp(&op, builder)))
388+
return signalPassFailure();
389+
390+
LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
391+
<< funcOp << "\n");
392+
LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
393+
394+
// 2. propagate in original order
395+
for (Operation &op : llvm::make_early_inc_range(block))
396+
if (failed(visitOp(&op, builder)))
397+
return signalPassFailure();
414398
}
415399
};

mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir

Lines changed: 0 additions & 26 deletions
This file was deleted.

mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir

Lines changed: 0 additions & 27 deletions
This file was deleted.

mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir

Lines changed: 0 additions & 49 deletions
This file was deleted.

0 commit comments

Comments
 (0)