Skip to content

[MLIR] Improve compose expand(collapse) pattern #117768

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 2, 2024

Conversation

IanWood1
Copy link
Contributor

@IanWood1 IanWood1 commented Nov 26, 2024

If expand(collapse) has a dimension that gets collapsed and then expanded to the same shape, the pattern would fail to canonicalize this to a single collapse shape. Line 341 was changed because the expand(collapse) could be a reinterpret-cast like sequence where the shapes differ but the rank is the same. This cannot be represented by a single collapse_shape op.

@IanWood1 IanWood1 requested review from qedawkins, pashu123 and Max191 and removed request for qedawkins and pashu123 November 26, 2024 19:02
@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2024

@llvm/pr-subscribers-mlir-tensor

Author: Ian Wood (IanWood1)

Changes

If expand(collapse) has a dimension that gets collapsed and then expanded to the same shape, the pattern would fail to canonicalize this to a single collapse shape. Line 341 was changed because the expand(collapse) could be a reinterpret-cast like sequence where the shapes differ but the rank is the same. This cannot be represented by a single collapse_shape op and should be converted to a cast.


Full diff: https://github.com/llvm/llvm-project/pull/117768.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+14-10)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+28)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 89bc57f09ec8ba..0357e34a2e0963 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -338,7 +338,7 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
 
     int64_t srcRank = srcType.getRank();
     int64_t resultRank = resultType.getRank();
-    if (srcType == resultType)
+    if (srcRank == resultRank)
       return failure();
 
     auto srcReassociation = collapseOp.getReassociationIndices();
@@ -388,12 +388,16 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape &&
-            llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
-          composedReassociation.push_back(srcIndices);
-        } else {
+        if (srcSubShape != resultSubShape ||
+            llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
           return std::nullopt;
         }
+        for (auto dim : llvm::seq<int64_t>(0, srcSubShape.size())) {
+          ReassociationIndices reassoc;
+          reassoc.push_back(srcIndices.front() + dim);
+          composedReassociation.push_back(reassoc);
+        }
+        continue;
       }
 
       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
@@ -403,11 +407,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
         return std::nullopt;
 
       // Remap the subshape indices back to the original srcShape.
-      for (auto &subshape_indices : *subShapeReassociation) {
-        ReassociationIndices shape_indices;
-        for (int64_t index : subshape_indices)
-          shape_indices.push_back(srcIndices.front() + index);
-        composedReassociation.push_back(shape_indices);
+      for (auto &subshapeIndices : *subShapeReassociation) {
+        ReassociationIndices shapeIndices;
+        for (int64_t index : subshapeIndices)
+          shapeIndices.push_back(srcIndices.front() + index);
+        composedReassociation.push_back(shapeIndices);
       }
     }
     return {std::move(composedReassociation)};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0b54c207dea84e..613ec066337294 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1382,6 +1382,34 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
 
 // -----
 
+func.func @compose_expand_of_collapse_static(%arg0 : tensor<4x32x10x64x2xf16>) -> tensor<4x32x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 10, 128] : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>
+  return %expanded : tensor<4x32x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_static
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x32x10x64x2xf16>
+//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1], [2], [3, 4]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %arg1 : index) -> tensor<4x?x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x?x10x64x2xf16> into tensor<?x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, %arg1,  10, 128] : tensor<?x10x128xf16> into tensor<4x?x10x128xf16>
+  return %expanded : tensor<4x?x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x?x10x64x2xf16>
+//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1], [2], [3, 4]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @zero_rank_reshape_multi
 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0

@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2024

@llvm/pr-subscribers-mlir

Author: Ian Wood (IanWood1)

Changes

If expand(collapse) has a dimension that gets collapsed and then expanded to the same shape, the pattern would fail to canonicalize this to a single collapse shape. Line 341 was changed because the expand(collapse) could be a reinterpret-cast like sequence where the shapes differ but the rank is the same. This cannot be represented by a single collapse_shape op and should be converted to a cast.


Full diff: https://github.com/llvm/llvm-project/pull/117768.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+14-10)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+28)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 89bc57f09ec8ba..0357e34a2e0963 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -338,7 +338,7 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
 
     int64_t srcRank = srcType.getRank();
     int64_t resultRank = resultType.getRank();
-    if (srcType == resultType)
+    if (srcRank == resultRank)
       return failure();
 
     auto srcReassociation = collapseOp.getReassociationIndices();
@@ -388,12 +388,16 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape &&
-            llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
-          composedReassociation.push_back(srcIndices);
-        } else {
+        if (srcSubShape != resultSubShape ||
+            llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
           return std::nullopt;
         }
+        for (auto dim : llvm::seq<int64_t>(0, srcSubShape.size())) {
+          ReassociationIndices reassoc;
+          reassoc.push_back(srcIndices.front() + dim);
+          composedReassociation.push_back(reassoc);
+        }
+        continue;
       }
 
       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
@@ -403,11 +407,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
         return std::nullopt;
 
       // Remap the subshape indices back to the original srcShape.
-      for (auto &subshape_indices : *subShapeReassociation) {
-        ReassociationIndices shape_indices;
-        for (int64_t index : subshape_indices)
-          shape_indices.push_back(srcIndices.front() + index);
-        composedReassociation.push_back(shape_indices);
+      for (auto &subshapeIndices : *subShapeReassociation) {
+        ReassociationIndices shapeIndices;
+        for (int64_t index : subshapeIndices)
+          shapeIndices.push_back(srcIndices.front() + index);
+        composedReassociation.push_back(shapeIndices);
       }
     }
     return {std::move(composedReassociation)};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0b54c207dea84e..613ec066337294 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1382,6 +1382,34 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
 
 // -----
 
+func.func @compose_expand_of_collapse_static(%arg0 : tensor<4x32x10x64x2xf16>) -> tensor<4x32x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 10, 128] : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>
+  return %expanded : tensor<4x32x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_static
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x32x10x64x2xf16>
+//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1], [2], [3, 4]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %arg1 : index) -> tensor<4x?x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x?x10x64x2xf16> into tensor<?x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, %arg1,  10, 128] : tensor<?x10x128xf16> into tensor<4x?x10x128xf16>
+  return %expanded : tensor<4x?x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x?x10x64x2xf16>
+//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1], [2], [3, 4]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @zero_rank_reshape_multi
 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0

If expand(collapse) has a dimension that gets collapsed and then
expanded to the same shape, the pattern would fail to canonicalize this
to a single collapse shape.

Signed-off-by: Ian Wood <[email protected]>
@IanWood1 IanWood1 merged commit fcfdabf into llvm:main Dec 2, 2024
8 checks passed
@IanWood1 IanWood1 deleted the improve_reshape_canon branch December 2, 2024 16:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants