[MLIR][Tensor] Perform shape inference via in-place modification (NFC) #111593
This is more efficient because it avoids creating a clone that is immediately removed. It also guards the insertion of the cast on the result, so the cast is only created when the destination type actually changed.
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir
Author: Mehdi Amini (joker-eph)
Changes
This is more efficient to avoid a clone that is immediately removed.
Full diff: https://github.com/llvm/llvm-project/pull/111593.diff
1 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 659eabd2e93880..0ac0899def21b5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4332,21 +4332,24 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
           rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
     }
     Value dest = packOp.getDest();
+    Type originalResultType = dest.getType();
     if (destShape != packOp.getDestType().getShape()) {
       auto newDestType = packOp.getDestType().clone(destShape);
       dest =
           rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
     }
-    auto clonedPackOp = cast<PackOp>(rewriter.clone(*packOp));
-    Value res = clonedPackOp.getResult();
-    rewriter.startOpModification(clonedPackOp);
-    clonedPackOp.getSourceMutable().assign(source);
-    clonedPackOp.getDestMutable().assign(dest);
-    res.setType(dest.getType());
-    rewriter.finalizeOpModification(clonedPackOp);
-
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(
-        packOp, packOp.getResult().getType(), clonedPackOp);
+    rewriter.modifyOpInPlace(packOp, [&] {
+      packOp.getSourceMutable().assign(source);
+      packOp.getDestMutable().assign(dest);
+      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+    });
+    // Insert a cast if needed
+    if (originalResultType != dest.getType()) {
+      rewriter.setInsertionPointAfter(packOp);
+      auto castOp =
+          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+    }
     return success();
   }
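For readers who want to see the difference between the two strategies without building MLIR, here is a minimal sketch in plain C++. It is a toy model, not MLIR code: `ToyOp`, `ToyBlock`, `refineByClone`, and `refineInPlace` are hypothetical names invented for this illustration, and types are represented as plain strings. It only mirrors the shape of the change: the old path clones the op and always adds a cast, while the new path updates the op in place and adds a cast only if the type changed.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for an operation with a single result (NOT an MLIR class).
struct ToyOp {
  std::string name;
  std::string resultType;
};

// Toy stand-in for a block of operations (NOT an MLIR class).
struct ToyBlock {
  std::vector<ToyOp> ops;
  std::size_t created = 0; // number of ops materialized through insertAt

  void insertAt(std::size_t pos, ToyOp op) {
    assert(pos <= ops.size());
    ops.insert(ops.begin() + static_cast<std::ptrdiff_t>(pos), std::move(op));
    ++created;
  }
};

// Old strategy: clone the op, refine the clone, always add a cast back to
// the original type, then erase the original op.
// Returns the index of the op that users should refer to afterwards.
std::size_t refineByClone(ToyBlock &block, std::size_t idx,
                          const std::string &inferredType) {
  const std::string originalType = block.ops[idx].resultType;
  ToyOp clone = block.ops[idx];
  clone.resultType = inferredType;
  block.insertAt(idx + 1, clone);
  block.insertAt(idx + 2, ToyOp{"cast", originalType});
  block.ops.erase(block.ops.begin() + static_cast<std::ptrdiff_t>(idx));
  return idx + 1; // the cast now sits right after the clone
}

// New strategy: refine the op in place and add a cast only when the
// result type actually changed.
std::size_t refineInPlace(ToyBlock &block, std::size_t idx,
                          const std::string &inferredType) {
  const std::string originalType = block.ops[idx].resultType;
  block.ops[idx].resultType = inferredType;
  if (originalType == inferredType)
    return idx; // nothing changed, users keep referring to the op itself
  block.insertAt(idx + 1, ToyOp{"cast", originalType});
  return idx + 1;
}
```

In this toy model, refining an op whose type is already static creates no new ops with `refineInPlace`, whereas `refineByClone` still creates two (the clone and an unnecessary cast).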
Both versions look good to me, thanks!
Thanks, @joker-eph, for the cleanup.