-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Teach TransferOptimization
to look through trivial aliases
#87805
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) Changes: This allows `TransferOptimization` to eliminate and forward stores that are to trivial aliases (rather than just to identical memref values). A trivial alias is (currently) defined as:
Full diff: https://github.com/llvm/llvm-project/pull/87805.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 0ffef6aabccc18..87a03e2f874761 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -88,6 +88,46 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
return false;
}
+/// Walk up the source chain until an operation that changes/defines the view of
+/// memory is found (i.e. skip operations that alias the entire view).
+Value skipFullyAliasingOperations(Value source) {
+ while (auto op = source.getDefiningOp()) {
+ if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
+ subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
+ // A `memref.subview` with an all zero offset, and all unit strides, still
+ // points to the same memory.
+ source = subViewOp.getSource();
+ } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
+ // A `memref.cast` still points to the same memory.
+ source = castOp.getSource();
+ } else {
+ return source;
+ }
+ }
+ return source;
+}
+
+/// Checks if two (memref) values are the same, or are statically known to
+/// alias the same region of memory.
+bool isSameViewOrTrivialAlias(Value a, Value b) {
+ return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
+}
+
+/// Walk up the source chain until an op other than a `memref.subview`
+/// or `memref.cast` is found.
+Value skipSubViewsAndCasts(Value source) {
+ while (auto op = source.getDefiningOp()) {
+ if (auto subView = dyn_cast<memref::SubViewOp>(op)) {
+ source = subView.getSource();
+ } else if (auto cast = dyn_cast<memref::CastOp>(op)) {
+ source = cast.getSource();
+ } else {
+ return source;
+ }
+ }
+ return source;
+}
+
/// For transfer_write to overwrite fully another transfer_write must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
@@ -104,10 +144,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
<< "\n");
llvm::SmallVector<Operation *, 8> blockingAccesses;
Operation *firstOverwriteCandidate = nullptr;
- Value source = write.getSource();
- // Skip subview ops.
- while (auto subView = source.getDefiningOp<memref::SubViewOp>())
- source = subView.getSource();
+ Value source = skipSubViewsAndCasts(write.getSource());
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
source.getUsers().end());
llvm::SmallDenseSet<Operation *, 32> processed;
@@ -116,8 +153,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
// If the user has already been processed skip.
if (!processed.insert(user).second)
continue;
- if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
- users.append(subView->getUsers().begin(), subView->getUsers().end());
+ if (isa<memref::SubViewOp, memref::CastOp>(user)) {
+ users.append(user->getUsers().begin(), user->getUsers().end());
continue;
}
if (isMemoryEffectFree(user))
@@ -126,7 +163,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
continue;
if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
// Check candidate that can override the store.
- if (write.getSource() == nextWrite.getSource() &&
+ if (isSameViewOrTrivialAlias(nextWrite.getSource(), write.getSource()) &&
checkSameValueWAW(nextWrite, write) &&
postDominators.postDominates(nextWrite, write)) {
if (firstOverwriteCandidate == nullptr ||
@@ -191,10 +228,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
<< "\n");
SmallVector<Operation *, 8> blockingWrites;
vector::TransferWriteOp lastwrite = nullptr;
- Value source = read.getSource();
- // Skip subview ops.
- while (auto subView = source.getDefiningOp<memref::SubViewOp>())
- source = subView.getSource();
+ Value source = skipSubViewsAndCasts(read.getSource());
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
source.getUsers().end());
llvm::SmallDenseSet<Operation *, 32> processed;
@@ -203,12 +237,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
// If the user has already been processed skip.
if (!processed.insert(user).second)
continue;
- if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
- users.append(subView->getUsers().begin(), subView->getUsers().end());
- continue;
- }
- if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
- users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
+ if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
+ users.append(user->getUsers().begin(), user->getUsers().end());
continue;
}
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
@@ -221,7 +251,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
cast<VectorTransferOpInterface>(read.getOperation()),
/*testDynamicValueUsingBounds=*/true))
continue;
- if (write.getSource() == read.getSource() &&
+ if (isSameViewOrTrivialAlias(read.getSource(), write.getSource()) &&
dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
lastwrite = write;
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 13957af014b89e..e47d26940afa2f 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
// `vector.transfer_write` would not be safe:
// %1 = vector.transfer_read %subview
// vector.transfer_write %1, %alloca
-// vector.transfer_write %vec, %collapse_shape
+// vector.transfer_write %vec, %collapse_shape
// %2 = vector.transfer_read %alloca
// vector.transfer_write %1, %subview
// Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,33 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
return
}
+
+// Here each read/write is to a different subview, but they all point to exact
+// same bit of memory (just through casts and subviews with unit strides and
+// zero offsets).
+// CHECK-LABEL: func @forward_and_eliminate_stores_through_trivial_aliases
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: vector.transfer_write
+// CHECK: return
+func.func @forward_and_eliminate_stores_through_trivial_aliases(
+ %buffer : memref<?x?xf32>, %vec: vector<[8]x[8]xf32>, %size: index, %a_size: index, %another_size: index
+) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+ %cst = arith.constant 0.0 : f32
+ vector.transfer_write %vec, %buffer[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+ %direct_subview = memref.subview %buffer[0, 0] [%a_size, %a_size] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+ %cast = memref.cast %direct_subview : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32>
+ %subview_of_cast = memref.subview %cast[0, 0] [%another_size, %another_size] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+ %21 = vector.transfer_read %direct_subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, 1], offset: ?>>, vector<[8]x[8]xf32>
+ %23 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %21) -> (vector<[8]x[8]xf32>) {
+ %24 = arith.addf %arg3, %arg3 : vector<[8]x[8]xf32>
+ scf.yield %24 : vector<[8]x[8]xf32>
+ }
+ vector.transfer_write %23, %subview_of_cast[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>
+ return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The idea looks good to me, but I wonder if running FoldMemRefAliasOps would help your case? I think they hopefully will all be folded into the same memref.subview ops.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! I think extending the transformation to deal with these simple alias cases make it more powerful. We know there are cases where some of these casts/subviews can't be folded.
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also LGTM cheers
407f642
to
93b80d7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
This allows `TransferOptimization` to eliminate and forward stores that are to trivial aliases (rather than just to identical memref values). A trivial alias is (currently) defined as: 1. A `memref.cast` 2. A `memref.subview` with a zero offset and unit strides 3. A chain of 1 and 2
I believe @matthias-springer has added related functionality a few weeks back? |
This allows
TransferOptimization
to eliminate and forward stores that are to trivial aliases (rather than just to identical memref values). A trivial alias is (currently) defined as:
memref.cast
memref.subview
with a zero offset and unit strides