Commit fcfdabf

[MLIR] Improve compose expand(collapse) pattern (#117768)

If an expand(collapse) pair contains a dimension group that is collapsed and then expanded back to the same shape, the pattern previously failed to canonicalize the pair into a single collapse_shape. Line 341 was changed because an expand(collapse) can also be a reinterpret-cast-like sequence in which the shapes differ but the ranks match; such a sequence cannot be represented by a single `collapse_shape` op.

Signed-off-by: Ian Wood <[email protected]>
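A sketch of the newly handled case, mirroring the tests added below (the `%arg0` value and SSA names are illustrative):

```mlir
// Before: group [0, 1] is collapsed to 128 and immediately expanded back to
// 4x32, so only the trailing 64x2 -> 128 collapse is actually needed.
%collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
    : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
%expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]]
    output_shape [4, 32, 10, 128]
    : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>

// After this patch, the pair canonicalizes to a single collapse:
%0 = tensor.collapse_shape %arg0 [[0], [1], [2], [3, 4]]
    : tensor<4x32x10x64x2xf16> into tensor<4x32x10x128xf16>
```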
Parent commit: 8935165

2 files changed: +40 additions, -10 deletions

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (12 additions, 10 deletions)
```diff
@@ -338,7 +338,7 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
 
     int64_t srcRank = srcType.getRank();
     int64_t resultRank = resultType.getRank();
-    if (srcType == resultType)
+    if (srcRank == resultRank)
       return failure();
 
     auto srcReassociation = collapseOp.getReassociationIndices();
```
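Why bail out whenever the ranks match: a single `collapse_shape` that preserves rank would have to be the identity, so any rank-preserving expand(collapse) with differing shapes is a reinterpret-cast-like sequence the pattern cannot fold. A hypothetical example (not from the patch):

```mlir
// Source and result both have rank 2, but the shapes differ (2x3 vs. 3x2).
%c = tensor.collapse_shape %arg0 [[0, 1]]
    : tensor<2x3xf32> into tensor<6xf32>
%e = tensor.expand_shape %c [[0, 1]] output_shape [3, 2]
    : tensor<6xf32> into tensor<3x2xf32>
// No single tensor.collapse_shape maps tensor<2x3xf32> to tensor<3x2xf32>,
// so the pattern now returns failure() whenever srcRank == resultRank.
```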
```diff
@@ -388,12 +388,14 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape &&
-            llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
-          composedReassociation.push_back(srcIndices);
-        } else {
+        if (srcSubShape != resultSubShape ||
+            llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
           return std::nullopt;
         }
+        for (auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
+          composedReassociation.emplace_back(1, srcIndices.front() + index);
+        }
+        continue;
       }
 
       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
```
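A hand-worked trace of the new branch over the static test added below (an illustration, not tool output):

```mlir
// srcReassociation:    [[0, 1], [2], [3, 4]]  over srcShape    = 4x32x10x64x2
// resultReassociation: [[0, 1], [2], [3]]     over resultShape = 4x32x10x128
//
// Group 0: srcSubShape [4, 32] equals resultSubShape [4, 32] with no dynamic
//          dims, so the new branch emits the unit groups [0], [1] and
//          continues; previously this case failed to canonicalize.
// Group 1: srcSubShape [10] equals resultSubShape [10]; emits [2].
// Group 2: srcSubShape [64, 2] and resultSubShape [128] differ in size, so
//          control falls through to findCollapsingReassociation, which
//          yields the group [3, 4].
//
// Composed reassociation: [[0], [1], [2], [3, 4]]
```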
```diff
@@ -403,11 +405,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
         return std::nullopt;
 
       // Remap the subshape indices back to the original srcShape.
-      for (auto &subshape_indices : *subShapeReassociation) {
-        ReassociationIndices shape_indices;
-        for (int64_t index : subshape_indices)
-          shape_indices.push_back(srcIndices.front() + index);
-        composedReassociation.push_back(shape_indices);
+      for (auto &subshapeIndices : *subShapeReassociation) {
+        ReassociationIndices shapeIndices;
+        for (int64_t index : subshapeIndices)
+          shapeIndices.push_back(srcIndices.front() + index);
+        composedReassociation.push_back(shapeIndices);
       }
     }
     return {std::move(composedReassociation)};
```

mlir/test/Dialect/Tensor/canonicalize.mlir (28 additions, 0 deletions)
```diff
@@ -1382,6 +1382,34 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
 
 // -----
 
+func.func @compose_expand_of_collapse_static(%arg0 : tensor<4x32x10x64x2xf16>) -> tensor<4x32x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 10, 128] : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>
+  return %expanded : tensor<4x32x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_static
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<4x32x10x64x2xf16>
+// CHECK:         %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:      [0], [1], [2], [3, 4]
+// CHECK:         return %[[RESULT]]
+
+// -----
+
+func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %arg1 : index) -> tensor<4x?x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x?x10x64x2xf16> into tensor<?x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, %arg1, 10, 128] : tensor<?x10x128xf16> into tensor<4x?x10x128xf16>
+  return %expanded : tensor<4x?x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<4x?x10x64x2xf16>
+// CHECK:         %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:      [0], [1], [2], [3, 4]
+// CHECK:         return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @zero_rank_reshape_multi
 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0
```
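Both tests expect the expand(collapse) pair to fold into one `tensor.collapse_shape` with reassociation `[[0], [1], [2], [3, 4]]`. The dynamic variant is still accepted because its matching group `[4, ?]` contains only one dynamic dimension, staying below the `llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2` bail-out in the updated pattern.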
