Skip to content

Commit 1403073

Browse files
[mlir][tensor] Fold rank-reducing insert_slice with inverse collapse_shape
Differential Revision: https://reviews.llvm.org/D139221
1 parent 50a2bb9 commit 1403073

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,41 @@ struct FoldExpandOfRankReducingExtract
4949
return success();
5050
}
5151
};
52+
53+
/// Fold insert_slice(collapse_shape) ops that cancel itself out.
54+
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> {
55+
using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
56+
57+
LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
58+
PatternRewriter &rewriter) const override {
59+
auto collapseShapeOp =
60+
insertSliceOp.getSource().getDefiningOp<CollapseShapeOp>();
61+
if (!collapseShapeOp)
62+
return failure();
63+
RankedTensorType srcType = collapseShapeOp.getSrcType();
64+
65+
// Only cases where the CollapseShapeOp can be folded away entirely are
66+
// supported. Moreover, only simple cases where the resulting InsertSliceOp
67+
// has no rank-reduction anymore are supported at the moment.
68+
RankedTensorType nonReducingInsertType =
69+
RankedTensorType::get(insertSliceOp.getStaticSizes(),
70+
insertSliceOp.getType().getElementType());
71+
if (nonReducingInsertType != srcType)
72+
return failure();
73+
74+
SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
75+
SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
76+
SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
77+
rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
78+
insertSliceOp, collapseShapeOp.getSrc(), insertSliceOp.getDest(),
79+
mixedOffsets, mixedSizes, mixedStrides);
80+
return success();
81+
}
82+
};
5283
} // namespace
5384

5485
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
5586
RewritePatternSet &patterns) {
56-
patterns.add<FoldExpandOfRankReducingExtract>(patterns.getContext());
87+
patterns.add<FoldExpandOfRankReducingExtract, FoldInsertOfRankReducingInsert>(
88+
patterns.getContext());
5789
}

mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,19 @@ func.func @expand_shape_of_rank_reducing_extract(
1717
: tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
1818
return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
1919
}
20+
21+
// -----
22+
23+
// CHECK-LABEL: func @rank_reducing_insert_of_collapse_shape(
24+
// CHECK-SAME: %[[t:.*]]: tensor<?x1x1x5xf32>
25+
// CHECK: %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
26+
// CHECK: return %[[insert]]
27+
func.func @rank_reducing_insert_of_collapse_shape(
28+
%t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index)
29+
-> tensor<?x?x?x?xf32> {
30+
%0 = tensor.collapse_shape %t [[0, 1], [2], [3]]
31+
: tensor<?x1x1x5xf32> into tensor<?x1x5xf32>
32+
%1 = tensor.insert_slice %0 into %d[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1]
33+
: tensor<?x1x5xf32> into tensor<?x?x?x?xf32>
34+
return %1 : tensor<?x?x?x?xf32>
35+
}

0 commit comments

Comments
 (0)