-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][mesh] Fix empty split_axes
sharding annotation
#108236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][mesh] Fix empty split_axes
sharding annotation
#108236
Conversation
The `split_axes` attribute is defined as "array attribute of array attributes". Following the definition, empty `split_axes` values should not be allowed, since that would break the definition and would lead to invalid IR. In such scenario, passes leveraging the mesh dialect can observe: * crashes in sharding-propagation; * creation of null MeshShardingAttrs in spmdization; * non roundtrippable IR. The patch prevents `split_axes` to become empty by modifying the `removeTrailingEmptySubArray` such that a minimum size of one is guaranteed when constructing the attribute, and adds a test that would crash without the change.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Matteo Franciolini (mfrancio) ChangesThe
The patch prevents Full diff: https://github.com/llvm/llvm-project/pull/108236.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 683975bbf215ed..db7b64fda57d7b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -98,9 +98,10 @@ inline bool isReductionLoop(utils::IteratorType iType) {
return iType == utils::IteratorType::reduction;
}
+// Remove empty subarrays of `array` until a minimum lengh of one is reached.
template <typename T>
void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
- while (!array.empty() && array.back().empty())
+ while (array.size() > 1 && array.back().empty())
array.pop_back();
}
diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
index f8521165e3244e..5297eeb666c1e1 100644
--- a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
@@ -18,7 +18,7 @@ func.func @matmul_shard_prallel_axis(
// CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32>
// CHECK: %[[SIN1_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
// CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32>
- // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = [] : !mesh.sharding
+ // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
// CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32>
// CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
// CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32>
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 5b00b45653dbb6..83136f613b020a 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -179,7 +179,7 @@ func.func @resolve_conflicting_annotations(
) -> tensor<2x2xf32> {
// CHECK: %[[SIN1_SHARDED1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to %[[SIN1_SHARDED1]] : tensor<2x3xf32>
- // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = [] : !mesh.sharding
+ // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}]] : !mesh.sharding
// CHECK-NEXT: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x3xf32>
// CHECK-NEXT: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<3x2xf32>
// CHECK-NEXT: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x2xf32>
@@ -266,3 +266,22 @@ func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
// CHECK-DAG: return %[[V12]]
return %6 : tensor<2x4x8xf32>
}
+
+// CHECK-LABEL: func.func @elementwise_duplicated_chain
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
+ %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V4:.*]] = tosa.sigmoid %[[V3]]
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+ // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S2]] : tensor<8x16xf32>
+ %s0 = mesh.sharding @mesh_2d split_axes = [[]] : !mesh.sharding
+ %2 = mesh.shard %1 to %s0 : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V5]]
+ return %2 : tensor<8x16xf32>
+}
|
split_axes
sharding annotation.split_axes
sharding annotation
ping for feedback @joker-eph, @sogartar? Sharding propagation and spmdization are severely affected by this bug. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is consistent with the doc.
I can see why it is annoying and error prone that every usage of the attribute must check first if there are any subarrays in the array, but requiring the existence of the single empty subarray looks like a quirk.
Thanks for the review. I think it is logically sound (if you expect ArrayAttr of Attr, then you must have an attribute as part of your array), and at the same time counter intuitive, since the semantic of the spec is "absence of attribute means replication". I'd like to try proposing something to make this more ergonomic whenever i have some cycles, but this patch makes this usable at least. I'll wait a couple of days for @joker-eph to give feedback as well before merging. |
The
split_axes
attribute is defined as "array attribute of arrayattributes". Following the definition, empty
split_axes
values shouldnot be allowed, since that would break the definition and would lead to
invalid IR. In such scenario, passes leveraging the mesh dialect can
observe:
The patch prevents
split_axes
to become empty by modifying theremoveTrailingEmptySubArray
such that a minimum size of one isguaranteed when constructing the attribute, and adds a test that would
crash without the change.