Skip to content

Commit 3ba205e

Browse files
maxbartelDavidGinten
authored and committed
fix: account for dynamic sizes while tiling
This is not super safe; when upstreaming, we should get feedback here. Also, it is not yet clear how to test this.
1 parent 0bdbc1c commit 3ba205e

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,10 +1188,32 @@ mlir::scf::tileAndFuseProducerOfSlice(
11881188
clonedProducerOp->getResult(resultNumber));
11891189
if (failed(tileAndFuseResult))
11901190
return std::nullopt;
1191-
// Note: Do not delete the candidateSliceOp, since its passed in from the
1192-
// caller.
1193-
rewriter.replaceAllUsesWith(candidateSliceOp,
1194-
tileAndFuseResult->tiledValues[0]);
1191+
1192+
// Check if the types are the same. If possible insert a cast. Fail otherwise.
1193+
if (tileAndFuseResult->tiledValues[0].getType() !=
1194+
candidateSliceOp.getResult().getType()) {
1195+
auto tileAndFuseResultType =
1196+
cast<RankedTensorType>(tileAndFuseResult->tiledValues[0].getType());
1197+
auto candidateSliceOpType =
1198+
cast<RankedTensorType>(candidateSliceOp.getResult().getType());
1199+
// We can only cast if the tileAndFuseResultType has a static shape and
1200+
// candidateSliceOp has a dynamic shape. Might be expanded in the future.
1201+
if (!tileAndFuseResultType.hasStaticShape() ||
1202+
candidateSliceOpType.hasStaticShape()) {
1203+
return std::nullopt;
1204+
}
1205+
1206+
auto castOp = rewriter.create<tensor::CastOp>(
1207+
candidateSliceOp->getLoc(), candidateSliceOpType, tileAndFuseResult->tiledValues[0]);
1208+
// Note: Do not delete the candidateSliceOp, since its passed in from the
1209+
// caller.
1210+
rewriter.replaceAllUsesWith(candidateSliceOp, castOp);
1211+
} else {
1212+
// Note: Do not delete the candidateSliceOp, since its passed in from the
1213+
// caller.
1214+
rewriter.replaceAllUsesWith(candidateSliceOp,
1215+
tileAndFuseResult->tiledValues[0]);
1216+
}
11951217
rewriter.eraseOp(clonedCandidateSliceOp);
11961218
rewriter.eraseOp(clonedProducerOp);
11971219

0 commit comments

Comments
 (0)