Skip to content

[mlir][mesh] Fix empty split_axes sharding annotation #108236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 21, 2024

Conversation

mfrancio
Copy link
Contributor

The split_axes attribute is defined as "array attribute of array
attributes". Following the definition, empty split_axes values should
not be allowed, since that would break the definition and would lead to
invalid IR. In such scenario, passes leveraging the mesh dialect can
observe:

  • crashes in sharding-propagation;
  • creation of null MeshShardingAttrs in spmdization;
  • non roundtrippable IR.

The patch prevents split_axes to become empty by modifying the
removeTrailingEmptySubArray such that a minimum size of one is
guaranteed when constructing the attribute, and adds a test that would
crash without the change.

The `split_axes` attribute is defined as "array attribute of array
attributes". Following the definition, empty `split_axes` values should
not be allowed, since that would break the definition and would lead to
invalid IR. In such scenario, passes leveraging the mesh dialect can
observe:
* crashes in sharding-propagation;
* creation of null MeshShardingAttrs in spmdization;
* non roundtrippable IR.

The patch prevents `split_axes` to become empty by modifying the
`removeTrailingEmptySubArray` such that a minimum size of one is
guaranteed when constructing the attribute, and adds a test that would
crash without the change.
@llvmbot
Copy link
Member

llvmbot commented Sep 11, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Matteo Franciolini (mfrancio)

Changes

The split_axes attribute is defined as "array attribute of array
attributes". Following the definition, empty split_axes values should
not be allowed, since that would break the definition and would lead to
invalid IR. In such scenario, passes leveraging the mesh dialect can
observe:

  • crashes in sharding-propagation;
  • creation of null MeshShardingAttrs in spmdization;
  • non roundtrippable IR.

The patch prevents split_axes to become empty by modifying the
removeTrailingEmptySubArray such that a minimum size of one is
guaranteed when constructing the attribute, and adds a test that would
crash without the change.


Full diff: https://github.com/llvm/llvm-project/pull/108236.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+2-1)
  • (modified) mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir (+1-1)
  • (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+20-1)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 683975bbf215ed..db7b64fda57d7b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -98,9 +98,10 @@ inline bool isReductionLoop(utils::IteratorType iType) {
   return iType == utils::IteratorType::reduction;
 }
 
+// Remove empty subarrays of `array` until a minimum lengh of one is reached.
 template <typename T>
 void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
-  while (!array.empty() && array.back().empty())
+  while (array.size() > 1 && array.back().empty())
     array.pop_back();
 }
 
diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
index f8521165e3244e..5297eeb666c1e1 100644
--- a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
@@ -18,7 +18,7 @@ func.func @matmul_shard_prallel_axis(
   // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32>
   // CHECK: %[[SIN1_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
   // CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32>
-  // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = [] : !mesh.sharding
+  // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
   // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32>
   // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
   // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32>
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 5b00b45653dbb6..83136f613b020a 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -179,7 +179,7 @@ func.func @resolve_conflicting_annotations(
 ) -> tensor<2x2xf32> {
   // CHECK: %[[SIN1_SHARDED1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}0]] : !mesh.sharding
   // CHECK-NEXT:  %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to %[[SIN1_SHARDED1]]  : tensor<2x3xf32>
-  // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = [] : !mesh.sharding
+  // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}]] : !mesh.sharding
   // CHECK-NEXT:  %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<2x3xf32>
   // CHECK-NEXT:  %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<3x2xf32>
   // CHECK-NEXT:  %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users  : tensor<2x2xf32>
@@ -266,3 +266,22 @@ func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x
   // CHECK-DAG: return %[[V12]]
   return %6 : tensor<2x4x8xf32>
 }
+
+// CHECK-LABEL: func.func @elementwise_duplicated_chain
+// CHECK-SAME:    %[[ARG:.*]]: tensor<8x16xf32>
+func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  // CHECK-NEXT:  %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users  : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V1:.*]] = tosa.sigmoid %[[V0]]
+  %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V4:.*]] = tosa.sigmoid %[[V3]]
+  %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+  // CHECK-NEXT:  %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+  // CHECK-NEXT:  %[[V5:.*]] = mesh.shard %[[V4]] to %[[S2]]  : tensor<8x16xf32>
+  %s0 = mesh.sharding @mesh_2d split_axes = [[]] : !mesh.sharding
+  %2 = mesh.shard %1 to %s0 : tensor<8x16xf32>
+  // CHECK-NEXT:  return %[[V5]]
+  return %2 : tensor<8x16xf32>
+}

@mfrancio mfrancio requested a review from joker-eph September 11, 2024 14:47
@mfrancio mfrancio changed the title [mlir][mesh] Fix empty split_axes sharding annotation. [mlir][mesh] Fix empty split_axes sharding annotation Sep 13, 2024
@mfrancio
Copy link
Contributor Author

ping for feedback @joker-eph, @sogartar?

Sharding propagation and spmdization are severely affected by this bug.

Copy link
Contributor

@sogartar sogartar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is consistent with the doc.
I can see why it is annoying and error prone that every usage of the attribute must check first if there are any subarrays in the array, but requiring the existence of the single empty subarray looks like a quirk.

@mfrancio
Copy link
Contributor Author

This change is consistent with the doc. I can see why it is annoying and error prone that every usage of the attribute must check first if there are any subarrays in the array, but requiring the existence of the single empty subarray looks like a quirk.

Thanks for the review. I think it is logically sound (if you expect ArrayAttr of Attr, then you must have an attribute as part of your array), and at the same time counter intuitive, since the semantic of the spec is "absence of attribute means replication". I'd like to try proposing something to make this more ergonomic whenever i have some cycles, but this patch makes this usable at least.

I'll wait a couple of days for @joker-eph to give feedback as well before merging.

@mfrancio mfrancio merged commit 2f664f2 into llvm:main Sep 21, 2024
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants