Skip to content

Commit 275a2b0

Browse files
authored
[MLIR][Tensor] Perform shape inference via in-place modification (NFC) (#111593)
This is more efficient to avoid a clone that is immediately removed. Also guard the insertion of a cast on the result on whether the destination type changed.
1 parent e2dc50c commit 275a2b0

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4332,21 +4332,25 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
43324332
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
43334333
}
43344334
Value dest = packOp.getDest();
4335-
if (destShape != packOp.getDestType().getShape()) {
4335+
RankedTensorType originalResultType = packOp.getDestType();
4336+
bool needUpdateDestType = (destShape != originalResultType.getShape());
4337+
if (needUpdateDestType) {
43364338
auto newDestType = packOp.getDestType().clone(destShape);
43374339
dest =
43384340
rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
43394341
}
4340-
auto clonedPackOp = cast<PackOp>(rewriter.clone(*packOp));
4341-
Value res = clonedPackOp.getResult();
4342-
rewriter.startOpModification(clonedPackOp);
4343-
clonedPackOp.getSourceMutable().assign(source);
4344-
clonedPackOp.getDestMutable().assign(dest);
4345-
res.setType(dest.getType());
4346-
rewriter.finalizeOpModification(clonedPackOp);
4347-
4348-
rewriter.replaceOpWithNewOp<tensor::CastOp>(
4349-
packOp, packOp.getResult().getType(), clonedPackOp);
4342+
rewriter.modifyOpInPlace(packOp, [&] {
4343+
packOp.getSourceMutable().assign(source);
4344+
packOp.getDestMutable().assign(dest);
4345+
packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
4346+
});
4347+
// Insert a cast if needed
4348+
if (needUpdateDestType) {
4349+
rewriter.setInsertionPointAfter(packOp);
4350+
auto castOp =
4351+
rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
4352+
rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
4353+
}
43504354
return success();
43514355
}
43524356

0 commit comments

Comments
 (0)