Skip to content

[mlir] Fix ComposeExpandOfCollapseOp for dynamic case #142663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 11, 2025

Conversation

IanWood1
Copy link
Contributor

@IanWood1 IanWood1 commented Jun 3, 2025

Changes findCollapsingReassociation to return nullopt in all cases where source shape has >=2 dynamic dims. expand(collapse) can reshape to in any valid output shape but a collapse can only collapse contiguous dimensions. When there are >=2 dynamic dimensions it is impossible to determine if it can be simplified to a collapse or if it is preforming a more advanced reassociation.

This problem was uncovered by #137963

@llvmbot
Copy link
Member

llvmbot commented Jun 3, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Ian Wood (IanWood1)

Changes

Changes findCollapsingReassociation to return nullopt in all cases where source shape has >=2 dynamic dims. expand(collapse) can reshape to in any valid output shape but a single collapse can only collapse contiguous dimensions. When there are >=2 dynamic dimensions it is impossible to determine if it can be simplified to a collapse or if it is preforming a more advanced reassociation.


Full diff: https://github.com/llvm/llvm-project/pull/142663.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+5-3)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+14)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index af575e10acc8e..c9f29ab8f15d5 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -387,11 +387,13 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
       auto resultSubShape =
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
+      if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2)
+        return std::nullopt;
+
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape != resultSubShape ||
-            llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
+        if (srcSubShape != resultSubShape)
           return std::nullopt;
-        }
+
         for (auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
           composedReassociation.emplace_back(1, srcIndices.front() + index);
         }
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 646b2197d9aa6..7e5e57423f44d 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1319,6 +1319,20 @@ func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %
 
 // -----
 
+func.func @no_compose_collapse_of_expand_dynamic(%arg0 : tensor<?x8x128x?xf16>, %arg1: index) -> tensor<?x128x?xf16> {
+  %collapse = tensor.collapse_shape %arg0 [[0, 1, 2, 3]] : tensor<?x8x128x?xf16> into tensor<?xf16>
+  %expanded_19 = tensor.expand_shape %collapse [[0, 1, 2]] output_shape [%arg1, 8, %arg1] : tensor<?xf16> into tensor<?x128x?xf16>
+  return %expanded_19 : tensor<?x128x?xf16>
+}
+// CHECK-LABEL: func @no_compose_collapse_of_expand_dynamic
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor
+//  CHECK-SAME:   %[[ARG1:.+]]: index
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]]
+//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]]
+//       CHECK:   return %[[EXPAND]]
+
+// -----
+
 // CHECK-LABEL: func @zero_rank_reshape_multi
 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0

Copy link
Contributor

@AGindinson AGindinson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM after the previous thread, can't approve due to commit access still pending

@IanWood1 IanWood1 merged commit 6f2ba47 into llvm:main Jun 11, 2025
7 checks passed
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
Changes `findCollapsingReassociation` to return nullopt in all cases
where source shape has `>=2` dynamic dims. `expand(collapse)` can
reshape to in any valid output shape but a collapse can only collapse
contiguous dimensions. When there are `>=2` dynamic dimensions it is
impossible to determine if it can be simplified to a collapse or if it
is preforming a more advanced reassociation.


This problem was uncovered by
llvm#137963

---------

Signed-off-by: Ian Wood <[email protected]>
akuhlens pushed a commit to akuhlens/llvm-project that referenced this pull request Jun 24, 2025
Changes `findCollapsingReassociation` to return nullopt in all cases
where source shape has `>=2` dynamic dims. `expand(collapse)` can
reshape to in any valid output shape but a collapse can only collapse
contiguous dimensions. When there are `>=2` dynamic dimensions it is
impossible to determine if it can be simplified to a collapse or if it
is preforming a more advanced reassociation.


This problem was uncovered by
llvm#137963

---------

Signed-off-by: Ian Wood <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants