@@ -1188,10 +1188,32 @@ mlir::scf::tileAndFuseProducerOfSlice(
1188
1188
clonedProducerOp->getResult (resultNumber));
1189
1189
if (failed (tileAndFuseResult))
1190
1190
return std::nullopt;
1191
- // Note: Do not delete the candidateSliceOp, since its passed in from the
1192
- // caller.
1193
- rewriter.replaceAllUsesWith (candidateSliceOp,
1194
- tileAndFuseResult->tiledValues [0 ]);
1191
+
1192
+ // Check if the types are the same. If possible insert a cast. Fail otherwise.
1193
+ if (tileAndFuseResult->tiledValues [0 ].getType () !=
1194
+ candidateSliceOp.getResult ().getType ()) {
1195
+ auto tileAndFuseResultType =
1196
+ cast<RankedTensorType>(tileAndFuseResult->tiledValues [0 ].getType ());
1197
+ auto candidateSliceOpType =
1198
+ cast<RankedTensorType>(candidateSliceOp.getResult ().getType ());
1199
+ // We can only cast if the tileAndFuseResultType has a static shape and
1200
+ // candidateSliceOp has a dynamic shape. Might be expanded in the future.
1201
+ if (!tileAndFuseResultType.hasStaticShape () ||
1202
+ candidateSliceOpType.hasStaticShape ()) {
1203
+ return std::nullopt;
1204
+ }
1205
+
1206
+ auto castOp = rewriter.create <tensor::CastOp>(
1207
+ candidateSliceOp->getLoc (), candidateSliceOpType, tileAndFuseResult->tiledValues [0 ]);
1208
+ // Note: Do not delete the candidateSliceOp, since it is passed in from the
1209
+ // caller.
1210
+ rewriter.replaceAllUsesWith (candidateSliceOp, castOp);
1211
+ } else {
1212
+ // Note: Do not delete the candidateSliceOp, since its passed in from the
1213
+ // caller.
1214
+ rewriter.replaceAllUsesWith (candidateSliceOp,
1215
+ tileAndFuseResult->tiledValues [0 ]);
1216
+ }
1195
1217
rewriter.eraseOp (clonedCandidateSliceOp);
1196
1218
rewriter.eraseOp (clonedProducerOp);
1197
1219
0 commit comments