Skip to content

Commit 0536c7e

Browse files
committed
fix bug
1 parent 930a4d7 commit 0536c7e

File tree

2 files changed: +19 −5 lines changed

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,6 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
104104
if (srcType != resultType)
105105
return nullptr;
106106

107-
// If the reshapes are expanding and then collapsing, the ops can be folded
108-
// despite multiple dynamic dimensions.
109-
if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
110-
return reshapeSrcOp.getSrc();
111-
// Otherwise, only 1 dynamic dimension is allowed.
112107
if (srcType == resultType &&
113108
llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
114109
return reshapeSrcOp.getSrc();
@@ -124,6 +119,10 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
124119
auto inverseReassociations = reshapeSrcOp.getReassociationIndices();
125120
if (reassociations != inverseReassociations)
126121
return nullptr;
122+
// If the reshapes are expanding and then collapsing, the ops can be folded
123+
// despite multiple dynamic dimensions.
124+
if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
125+
return reshapeSrcOp.getSrc();
127126
ArrayRef<int64_t> expandedSrcShape = srcType.getShape();
128127
ArrayRef<int64_t> expandedResultShape = resultType.getShape();
129128
if (llvm::none_of(reassociations, [&](auto reInd) {

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,21 @@ func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1:
11691169

11701170
// -----
11711171

1172+
func.func @no_fold_parallel_collapse_of_expand_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
1173+
-> tensor<?x?x?xf32> {
1174+
%0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [%arg1, %arg2, %arg3, %arg4]
1175+
: tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
1176+
%1 = tensor.collapse_shape %0 [[0], [1], [2, 3]]
1177+
: tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
1178+
return %1 : tensor<?x?x?xf32>
1179+
}
1180+
// CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic
1181+
// CHECK: tensor.expand_shape
1182+
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
1183+
// CHECK: return %[[COLLAPSE]]
1184+
1185+
// -----
1186+
11721187
func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
11731188
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
11741189
: tensor<3x4x4xf32> into tensor<12x4xf32>

0 commit comments

Comments (0)