Skip to content

Commit 634e253

Browse files
authored
[mlir] Add special case for 0-D tensor when fusing expand from collapse (#130838)
One fusion pattern for collapse_shape -> expand_shape was added in a95ad2d, however if the intermediate tensor between a collapse and expand is a 0-D tensor, then the `reassociation_map` for these two are special cases and can't be generally fused in this function `BubbleUpExpandThroughParallelCollapse`.
1 parent 701148f commit 634e253

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,12 @@ struct BubbleUpExpandThroughParallelCollapse
160160
auto expandReInds = expandOp.getReassociationIndices();
161161
auto collapseReInds = collapseOp.getReassociationIndices();
162162

163+
// Special case where the collapsed tensor to expand is a 0-D tensor,
164+
// then the reassociation maps will be empty and not produce valid results.
165+
if (expandReInds.size() == 0) {
166+
return failure();
167+
}
168+
163169
// Reshapes are parallel to each other if none of the reassociation indices
164170
// have greater than 1 index for both reshapes.
165171
for (auto [expandReassociation, collapseReassociation] :

mlir/test/Dialect/Tensor/bubble-reshapes.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,17 @@ func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %
4545
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
4646
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1], [2, 3]]
4747
// CHECK: return %[[EXPAND]]
48+
49+
// -----
50+
51+
func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
52+
%collapse = tensor.collapse_shape %arg0 [] : tensor<?xf32> into tensor<f32>
53+
%expand = tensor.expand_shape %collapse []
54+
output_shape [%s0, %s1, %s2, %s3] : tensor<f32> into tensor<?x?x?x?xf32>
55+
return %expand : tensor<?x?x?x?xf32>
56+
}
57+
// CHECK: func @no_bubble_0d_tensor_reshapes
58+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
59+
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}]
60+
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}]
61+
// CHECK: return %[[EXPAND]]

0 commit comments

Comments
 (0)