@@ -49,9 +49,41 @@ struct FoldExpandOfRankReducingExtract
49
49
return success ();
50
50
}
51
51
};
52
+
53
+ // / Fold insert_slice(collapse_shape) ops that cancel itself out.
54
+ struct FoldInsertOfRankReducingInsert : public OpRewritePattern <InsertSliceOp> {
55
+ using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
56
+
57
+ LogicalResult matchAndRewrite (InsertSliceOp insertSliceOp,
58
+ PatternRewriter &rewriter) const override {
59
+ auto collapseShapeOp =
60
+ insertSliceOp.getSource ().getDefiningOp <CollapseShapeOp>();
61
+ if (!collapseShapeOp)
62
+ return failure ();
63
+ RankedTensorType srcType = collapseShapeOp.getSrcType ();
64
+
65
+ // Only cases where the CollapseShapeOp can be folded away entirely are
66
+ // supported. Moreover, only simple cases where the resulting InsertSliceOp
67
+ // has no rank-reduction anymore are supported at the moment.
68
+ RankedTensorType nonReducingInsertType =
69
+ RankedTensorType::get (insertSliceOp.getStaticSizes (),
70
+ insertSliceOp.getType ().getElementType ());
71
+ if (nonReducingInsertType != srcType)
72
+ return failure ();
73
+
74
+ SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets ();
75
+ SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes ();
76
+ SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides ();
77
+ rewriter.replaceOpWithNewOp <tensor::InsertSliceOp>(
78
+ insertSliceOp, collapseShapeOp.getSrc (), insertSliceOp.getDest (),
79
+ mixedOffsets, mixedSizes, mixedStrides);
80
+ return success ();
81
+ }
82
+ };
52
83
} // namespace
53
84
54
85
void mlir::tensor::populateReassociativeReshapeFoldingPatterns (
55
86
RewritePatternSet &patterns) {
56
- patterns.add <FoldExpandOfRankReducingExtract>(patterns.getContext ());
87
+ patterns.add <FoldExpandOfRankReducingExtract, FoldInsertOfRankReducingInsert>(
88
+ patterns.getContext ());
57
89
}
0 commit comments