-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tensor] Apply InsertSliceOfTransferWriteOpFolder
only when transfer_write
overwrites all elements of insert_slice
#108803
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tensor] Apply InsertSliceOfTransferWriteOpFolder
only when transfer_write
overwrites all elements of insert_slice
#108803
Conversation
@MacDue https://mlir.llvm.org/doxygen/classmlir_1_1ValueBoundsConstraintSet.html For static sizes we just check Also, do we want to throw a match failure if it does not overwrite all elements, or just return with |
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir Author: Rajveer Singh Bharadwaj (Rajveer100) ChangesResolves #101708 Full diff: https://github.com/llvm/llvm-project/pull/108803.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 5396531922aab3..f7a490844e95af 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -23,6 +23,8 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <cstddef>
+#include <sys/_types/_int64_t.h>
#include <type_traits>
namespace mlir {
@@ -67,6 +69,12 @@ class InsertSliceOfTransferWriteOpFolder final
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
PatternRewriter &rewriter) const override;
+
+private:
+ static bool
+ doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp,
+ tensor::InsertSliceOp insertSliceOp,
+ MLIRContext *context);
};
} // namespace
@@ -136,6 +144,11 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
if (failed(preconditionResult))
return preconditionResult;
+ if (!doesTransferWriteCoverInsertSlice(writeOp, insertSliceOp,
+ rewriter.getContext()))
+ return rewriter.notifyMatchFailure(
+ insertSliceOp, "transfer_write does not cover insert_slice");
+
SmallVector<Value> indices(writeOp.getIndices().begin(),
writeOp.getIndices().end());
SmallVector<Value> sourceIndices;
@@ -154,6 +167,13 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
return success();
}
+bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
+ vector::TransferWriteOp writeOp, tensor::InsertSliceOp insertSliceOp,
+ MLIRContext *context) {
+ // Todo
+ return true;
+}
+
template <typename OpTy>
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index 1a84e141049325..7ba24511e96ba5 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -226,6 +226,15 @@ func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x
// -----
+func.func @insert_slice_of_transfer_write(%arg0: tensor<1000x1000xf32>, %arg1: vector<5x6xf32>, %arg2: index, %arg3: tensor<100x100xf32>) -> tensor<1000x1000xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.transfer_write %arg1, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<100x100xf32>
+ %inserted_slice = tensor.insert_slice %0 into %arg0[3, %arg2] [100, 100] [1, 1] : tensor<100x100xf32> into tensor<1000x1000xf32>
+ return %inserted_slice : tensor<1000x1000xf32>
+}
+
+// -----
+
// CHECK-DAG: #[[$d0d2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-LABEL: func @insert_slice_of_transfer_write_swappy_rank_extending(
|
Yes, my idea where was you check the size of the vector dimensions >= the upper bound on the size of the dynamic tensor destination (if a upper bound can be resolved).
Yep 👍 For a first patch maybe just fix the issue for static sizes and fail for dynamic sizes (where the transform may not be correct). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the PR ready for review?
@hanhanW PS: Added checks for static sizes. |
b5ba23d
to
6b4aec5
Compare
6b4aec5
to
cac2b77
Compare
@MacDue |
cac2b77
to
17a61bc
Compare
17a61bc
to
23def0a
Compare
Do we need an issue for dynamic shapes, or should I just mention it in the following PR? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of referencing the issue only, can you make the PR description more descriptive? Even just inlining the words from the issue is better, because it saves a level of routing.
23def0a
to
fa5d6a9
Compare
I have updated the commit description, let me know if that works well. |
You need to put it in the PR description or it won't appear on the commit after you merge this :) |
fa5d6a9
to
b216f58
Compare
…transfer_write` overwrites all elements of `insert_slice` Resolves llvm#101708 The updated logic now correctly checks if `transfer_write` completely overwrites `insert_slice` and only then applies the rewrite for this pattern. This check currently covers static sizes, for dynamic sizes value bounds analysis is needed (see `TODO:`).
b216f58
to
44fcb81
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM 👍
Thanks for the approval, could you land this for me?! I will try and land a PR for dynamic shapes. I presume we need to conform to the value bounds interface for the ops? |
I'll land this tomorrow :) (just in case @hanhanW has any further comments) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, I will land the PR.
…transfer_write` overwrites all elements of `insert_slice` (llvm#108803) Resolves llvm#101708 The updated logic now correctly checks if `transfer_write` completely overwrites `insert_slice` and only then applies the rewrite for this pattern. This check currently covers static sizes, for dynamic sizes value bounds analysis is needed (see `TODO:`).
Resolves #101708
The updated logic now correctly checks if
transfer_write
completelyoverwrites
insert_slice
and only then applies the rewrite for this pattern.This check currently covers static sizes, for dynamic sizes
value bounds analysis is needed (see
TODO:
).