@@ -1085,7 +1085,24 @@ class InsertSliceOpConstantArgumentFolder final
1085
1085
}
1086
1086
};
1087
1087
1088
- // / Fold tensor_casts with insert_slice operations.
1088
+ // / Fold tensor_casts with insert_slice operations. If the source or destination
1089
+ // / tensor is a tensor_cast that removes static type information, the cast is
1090
+ // / folded into the insert_slice operation. E.g.:
1091
+ // /
1092
+ // / ```mlir
1093
+ // / %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
1094
+ // / %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
1095
+ // / ```
1096
+ // /
1097
+ // / folds into:
1098
+ // /
1099
+ // / ```mlir
1100
+ // / %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
1101
+ // / ```
1102
+ // /
1103
+ // / Note: When folding a cast on the destination tensor, the result of the
1104
+ // / insert_slice operation is casted to ensure that the type of the result did
1105
+ // / not change.
1089
1106
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
1090
1107
using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1091
1108
@@ -1123,12 +1140,63 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
1123
1140
return success ();
1124
1141
}
1125
1142
};
1143
+
1144
+ // / If additional static type information can be deduced from a insert_slice's
1145
+ // / size operands, insert an explicit cast of the op's source operand. This
1146
+ // / enables other canonicalization patterns that are matching for tensor_cast
1147
+ // / ops such as `ForOpTensorCastFolder` in SCF.
1148
+ // /
1149
+ // / Example:
1150
+ // /
1151
+ // / ```mlir
1152
+ // / %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
1153
+ // / : tensor<?x?xf32> into ...
1154
+ // / ```
1155
+ // /
1156
+ // / folds into:
1157
+ // /
1158
+ // / ```mlir
1159
+ // / %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
1160
+ // / %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
1161
+ // / : tensor<64x64xf32> into ...
1162
+ // / ```
1163
+ struct InsertSliceOpSourceCastInserter final
1164
+ : public OpRewritePattern<InsertSliceOp> {
1165
+ using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1166
+
1167
+ LogicalResult matchAndRewrite (InsertSliceOp insertSliceOp,
1168
+ PatternRewriter &rewriter) const override {
1169
+ RankedTensorType srcType = insertSliceOp.getSourceType ();
1170
+ if (srcType.getRank () != insertSliceOp.getType ().getRank ())
1171
+ return failure ();
1172
+ SmallVector<int64_t > newSrcShape (srcType.getShape ().begin (),
1173
+ srcType.getShape ().end ());
1174
+ for (int64_t i = 0 ; i < srcType.getRank (); ++i) {
1175
+ if (Optional<int64_t > constInt =
1176
+ getConstantIntValue (insertSliceOp.getMixedSizes ()[i]))
1177
+ newSrcShape[i] = *constInt;
1178
+ }
1179
+ RankedTensorType newSrcType =
1180
+ RankedTensorType::get (newSrcShape, srcType.getElementType ());
1181
+ if (srcType == newSrcType)
1182
+ return failure ();
1183
+
1184
+ // srcType and newSrcType are different. Insert a cast.
1185
+ Value cast = rewriter.create <tensor::CastOp>(
1186
+ insertSliceOp.getLoc (), newSrcType, insertSliceOp.source ());
1187
+ rewriter.replaceOpWithNewOp <InsertSliceOp>(
1188
+ insertSliceOp, cast, insertSliceOp.dest (),
1189
+ insertSliceOp.getMixedOffsets (), insertSliceOp.getMixedSizes (),
1190
+ insertSliceOp.getMixedStrides ());
1191
+ return success ();
1192
+ }
1193
+ };
1126
1194
} // namespace
1127
1195
1128
1196
void InsertSliceOp::getCanonicalizationPatterns (RewritePatternSet &results,
1129
1197
MLIRContext *context) {
1130
- results.add <InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder>(
1131
- context);
1198
+ results.add <InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
1199
+ InsertSliceOpSourceCastInserter>( context);
1132
1200
}
1133
1201
1134
1202
// ===----------------------------------------------------------------------===//
0 commit comments