Skip to content

Commit 2f664f2

Browse files
authored
[mlir][mesh] Fix empty split_axes sharding annotation (#108236)
The `split_axes` attribute is defined as an "array attribute of array attributes". Following the definition, empty `split_axes` values should not be allowed, since that would break the definition and would lead to invalid IR. In such a scenario, passes leveraging the mesh dialect can observe: * crashes in sharding-propagation; * creation of null MeshShardingAttrs in spmdization; * non-roundtrippable IR. The patch prevents `split_axes` from becoming empty by modifying `removeTrailingEmptySubArray` such that a minimum size of one is guaranteed when constructing the attribute, and adds a test that would crash without the change.
1 parent 041b0a8 commit 2f664f2

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,10 @@ inline bool isReductionLoop(utils::IteratorType iType) {
9898
return iType == utils::IteratorType::reduction;
9999
}
100100

101+
// Remove trailing empty subarrays of `array` until a minimum length of one is reached.
101102
template <typename T>
102103
void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
103-
while (!array.empty() && array.back().empty())
104+
while (array.size() > 1 && array.back().empty())
104105
array.pop_back();
105106
}
106107

mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func.func @matmul_shard_prallel_axis(
1818
// CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32>
1919
// CHECK: %[[SIN1_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
2020
// CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32>
21-
// CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = [] : !mesh.sharding
21+
// CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
2222
// CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32>
2323
// CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
2424
// CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32>

mlir/test/Dialect/Mesh/sharding-propagation.mlir

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ func.func @resolve_conflicting_annotations(
179179
) -> tensor<2x2xf32> {
180180
// CHECK: %[[SIN1_SHARDED1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}0]] : !mesh.sharding
181181
// CHECK-NEXT: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to %[[SIN1_SHARDED1]] : tensor<2x3xf32>
182-
// CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = [] : !mesh.sharding
182+
// CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}]] : !mesh.sharding
183183
// CHECK-NEXT: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x3xf32>
184184
// CHECK-NEXT: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<3x2xf32>
185185
// CHECK-NEXT: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x2xf32>
@@ -266,3 +266,22 @@ func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
266266
// CHECK-DAG: return %[[V12]]
267267
return %6 : tensor<2x4x8xf32>
268268
}
269+
270+
// CHECK-LABEL: func.func @elementwise_duplicated_chain
271+
// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
272+
func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
273+
// CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
274+
// CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
275+
// CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
276+
%0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
277+
// CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
278+
// CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
279+
// CHECK-NEXT: %[[V4:.*]] = tosa.sigmoid %[[V3]]
280+
%1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
281+
// CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
282+
// CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S2]] : tensor<8x16xf32>
283+
%s0 = mesh.sharding @mesh_2d split_axes = [[]] : !mesh.sharding
284+
%2 = mesh.shard %1 to %s0 : tensor<8x16xf32>
285+
// CHECK-NEXT: return %[[V5]]
286+
return %2 : tensor<8x16xf32>
287+
}

0 commit comments

Comments
 (0)