Skip to content

[MLIR][Tensor] Perform shape inference via in-place modification (NFC) #111593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 9, 2024

Conversation

joker-eph
Copy link
Collaborator

This is more efficient to avoid a clone that is immediately removed.
Also guard the insertion of a cast on the result on whether the destination type changed.

This is more efficient to avoid a clone that is immediately removed.
Also guard the insertion of a cast on the result on whether the destination
type changed.
@llvmbot
Copy link
Member

llvmbot commented Oct 8, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

Changes

This is more efficient to avoid a clone that is immediately removed.
Also guard the insertion of a cast on the result on whether the destination type changed.


Full diff: https://github.com/llvm/llvm-project/pull/111593.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+13-10)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 659eabd2e93880..0ac0899def21b5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4332,21 +4332,24 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
           rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
     }
     Value dest = packOp.getDest();
+    Type originalResultType = dest.getType();
     if (destShape != packOp.getDestType().getShape()) {
       auto newDestType = packOp.getDestType().clone(destShape);
       dest =
           rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
     }
-    auto clonedPackOp = cast<PackOp>(rewriter.clone(*packOp));
-    Value res = clonedPackOp.getResult();
-    rewriter.startOpModification(clonedPackOp);
-    clonedPackOp.getSourceMutable().assign(source);
-    clonedPackOp.getDestMutable().assign(dest);
-    res.setType(dest.getType());
-    rewriter.finalizeOpModification(clonedPackOp);
-
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(
-        packOp, packOp.getResult().getType(), clonedPackOp);
+    rewriter.modifyOpInPlace(packOp, [&] {
+      packOp.getSourceMutable().assign(source);
+      packOp.getDestMutable().assign(dest);
+      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+    });
+    // Insert a cast if needed
+    if (originalResultType != dest.getType()) {
+      rewriter.setInsertionPointAfter(packOp);
+      auto castOp =
+          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+    }
     return success();
   }
 

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both versions look good to me, thanks!

@joker-eph joker-eph merged commit 275a2b0 into llvm:main Oct 9, 2024
6 of 9 checks passed
@joker-eph joker-eph deleted the in-place branch October 9, 2024 07:44
@pashu123
Copy link
Member

pashu123 commented Oct 9, 2024

Thanks, @joker-eph, for the cleanup.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants