Skip to content

Commit e165320

Browse files
committed
adding tests for forward and forward-backward sharding propagation
1 parent 73e1dca commit e165320

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" %s | FileCheck %s
2+
3+
#map = affine_map<(d0, d1) -> (d0, d1)>
4+
module {
5+
mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
6+
func.func @test_forward() -> tensor<6x6xi32> {
7+
%c1_i32 = arith.constant 1 : i32
8+
// CHECK: tensor.empty()
9+
%0 = tensor.empty() : tensor<6x6xi32>
10+
%sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
11+
// CHECK-COUNT-2: mesh.shard
12+
%sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32>
13+
%1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32>
14+
// CHECK: tensor.empty()
15+
// CHECK-NOT: mesh.shard @
16+
%2 = tensor.empty() : tensor<6x6xi32>
17+
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
18+
: tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
19+
^bb0(%in: i32, %in_2: i32, %out: i32):
20+
%9 = arith.addi %in, %in_2 : i32
21+
linalg.yield %9 : i32
22+
} -> tensor<6x6xi32>
23+
// CHECK: return
24+
return %3 : tensor<6x6xi32>
25+
}
26+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward-backward}))" %s | FileCheck %s
2+
3+
#map = affine_map<(d0, d1) -> (d0, d1)>
4+
module {
5+
mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
6+
func.func @test_forward() -> tensor<6x6xi32> {
7+
%c1_i32 = arith.constant 1 : i32
8+
// CHECK: tensor.empty()
9+
%0 = tensor.empty() : tensor<6x6xi32>
10+
// CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]]
11+
%sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
12+
%annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32>
13+
%1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32>
14+
%2 = tensor.empty() : tensor<6x6xi32>
15+
// CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]]
16+
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1
17+
: tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
18+
^bb0(%in: i32, %in_2: i32, %out: i32):
19+
%9 = arith.addi %in, %in_2 : i32
20+
linalg.yield %9 : i32
21+
} -> tensor<6x6xi32>
22+
%sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding
23+
%annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32>
24+
// CHECK: return
25+
return %annotated_col : tensor<6x6xi32>
26+
}
27+
}

0 commit comments

Comments
 (0)