
Commit 2fa8004

adding option for traversal order in sharding propagation
1 parent: fe28ea3 · commit: 2fa8004

4 files changed (+106, -10 lines)

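With this change, the traversal order can be selected on the pass's command line. For example (the input file name here is hypothetical; the pipeline syntax mirrors the RUN line in the new test below):

mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward-forward}))" input.mlir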

mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h

Lines changed: 12 additions & 0 deletions
@@ -19,6 +19,18 @@ class FuncOp;
 
 namespace mesh {
 
+/// This enum controls the traversal order for the sharding propagation.
+enum class TraversalOrder {
+  /// Forward traversal.
+  Forward,
+  /// Backward traversal.
+  Backward,
+  /// Forward then backward traversal.
+  ForwardBackward,
+  /// Backward then forward traversal.
+  BackwardForward
+};
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td

Lines changed: 15 additions & 0 deletions
@@ -24,6 +24,21 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
     operation, and the operations themselves are added with sharding option
     attributes.
   }];
+  let options = [
+    Option<"traversal", "traversal",
+           "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
+           "Traversal order to use for sharding propagation:",
+           [{::llvm::cl::values(
+             clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
+                        "Forward only traversal."),
+             clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
+                        "backward only traversal."),
+             clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
+                        "forward-backward traversal."),
+             clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
+                        "backward-forward traversal.")
+           )}]>,
+  ];
   let dependentDialects = [
     "mesh::MeshDialect"
   ];
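For reference, the accepted values of the new traversal option, as defined above:

sharding-propagation{traversal=forward}            Forward only traversal.
sharding-propagation{traversal=backward}           Backward only traversal.
sharding-propagation{traversal=forward-backward}   Forward then backward traversal.
sharding-propagation{traversal=backward-forward}   Backward then forward traversal (default).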

mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp

Lines changed: 30 additions & 10 deletions
@@ -362,6 +362,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 //===----------------------------------------------------------------------===//
 struct ShardingPropagation
     : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
+
+  using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
+
   void runOnOperation() override {
     FunctionOpInterface funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
@@ -383,17 +386,34 @@ struct ShardingPropagation
     });
 
     // 1. propagate in reversed order
-    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
-      if (failed(visitOp(&op, builder)))
-        return signalPassFailure();
-
-    LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
-                      << funcOp << "\n");
-    LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+    if (traversal == TraversalOrder::Backward ||
+        traversal == TraversalOrder::BackwardForward) {
+      for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+        if (failed(visitOp(&op, builder)))
+          return signalPassFailure();
+      if (traversal == TraversalOrder::BackwardForward) {
+        LLVM_DEBUG(DBGS() << "After backward order propagation:\n"
+                          << funcOp << "\n");
+        LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+      }
+    }
 
     // 2. propagate in original order
-    for (Operation &op : llvm::make_early_inc_range(block))
-      if (failed(visitOp(&op, builder)))
-        return signalPassFailure();
+    if (traversal != TraversalOrder::Backward) {
+      for (Operation &op : llvm::make_early_inc_range(block))
+        if (failed(visitOp(&op, builder)))
+          return signalPassFailure();
+      if (traversal == TraversalOrder::ForwardBackward) {
+        LLVM_DEBUG(DBGS() << "After forward order propagation:\n"
+                          << funcOp << "\n");
+        LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+      }
+    }
+
+    // 3. propagate in backward order if needed
+    if (traversal == TraversalOrder::ForwardBackward)
+      for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+        if (failed(visitOp(&op, builder)))
+          return signalPassFailure();
   }
 };
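Taken together, the new control flow runs only the forward pass for traversal=forward, only the backward pass for traversal=backward, forward then backward for forward-backward, and backward then forward for the default backward-forward. Below is a minimal sketch of selecting a non-default order from C++; it assumes the tablegen-generated ShardingPropagationOptions struct and createShardingPropagation() factory overload (those names come from the generated Passes.h.inc and are not shown in this diff, so treat them as assumptions):

// Hedged sketch: ShardingPropagationOptions and createShardingPropagation are
// assumed to be the tablegen-generated option struct / factory for this pass.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

static void addShardingPropagation(mlir::PassManager &pm) {
  mlir::mesh::ShardingPropagationOptions options;
  // Run the forward traversal first, then the backward traversal.
  options.traversal = mlir::mesh::TraversalOrder::ForwardBackward;
  // The pass is an InterfacePass on FunctionOpInterface, so nest it under
  // func.func, as the test's pass pipeline does.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mesh::createShardingPropagation(options));
}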
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
+  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
+    %0 = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[v1:%.*]] = linalg.fill ins
+    // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
+    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+    %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
+    // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
+    %3 = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
+    // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+    // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
+    // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
+    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
+        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
+    ^bb0(%in: i32, %in_2: i32, %out: i32):
+      %9 = arith.addi %in, %in_2 : i32
+      linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    %c0_i32 = arith.constant 0 : i32
+    %6 = tensor.empty() : tensor<i32>
+    %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
+    // CHECK: [[vreduced:%.*]] = linalg.reduce ins
+    // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial = sum [0] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
+    %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1]
+      (%in: i32, %init: i32) {
+        %9 = arith.addi %in, %init : i32
+        linalg.yield %9 : i32
+      }
+    // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
+    %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
+    %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
+    return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
+  }
+}
