Skip to content

[mlir][vector] Teach TransferOptimization to look through trivial aliases #87805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 16, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Apr 5, 2024

This allows TransferOptimization to eliminate and forward stores that are to trivial aliases (rather than just to identical memref values).

A trivial aliases is (currently) defined as:

  1. A memref.cast
  2. A memref.subview with a zero offset and unit strides
  3. A chain of 1 and 2

@llvmbot
Copy link
Member

llvmbot commented Apr 5, 2024

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

This allows TransferOptimization to eliminate and forward stores that are to trivial aliases (rather than just to identical memref values).

A trivial aliases is (currently) defined as:

  1. A memref.cast
  2. A memref.subview with a zero offset and unit strides
  3. A chain of 1 and 2

Full diff: https://github.com/llvm/llvm-project/pull/87805.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+48-18)
  • (modified) mlir/test/Dialect/Vector/vector-transferop-opt.mlir (+31-1)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 0ffef6aabccc18..87a03e2f874761 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -88,6 +88,46 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
   return false;
 }
 
+/// Walk up the source chain until an operation that changes/defines the view of
+/// memory is found (i.e. skip operations that alias the entire view).
+Value skipFullyAliasingOperations(Value source) {
+  while (auto op = source.getDefiningOp()) {
+    if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
+        subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
+      // A `memref.subview` with an all zero offset, and all unit strides, still
+      // points to the same memory.
+      source = subViewOp.getSource();
+    } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
+      // A `memref.cast` still points to the same memory.
+      source = castOp.getSource();
+    } else {
+      return source;
+    }
+  }
+  return source;
+}
+
+/// Checks if two (memref) values are are the same, or are statically known to
+/// alias the same region of memory.
+bool isSameViewOrTrivialAlias(Value a, Value b) {
+  return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
+}
+
+/// Walk up the source chain until something an op other than a `memref.subview`
+/// or `memref.cast` is found.
+Value skipSubViewsAndCasts(Value source) {
+  while (auto op = source.getDefiningOp()) {
+    if (auto subView = dyn_cast<memref::SubViewOp>(op)) {
+      source = subView.getSource();
+    } else if (auto cast = dyn_cast<memref::CastOp>(op)) {
+      source = cast.getSource();
+    } else {
+      return source;
+    }
+  }
+  return source;
+}
+
 /// For transfer_write to overwrite fully another transfer_write must:
 /// 1. Access the same memref with the same indices and vector type.
 /// 2. Post-dominate the other transfer_write operation.
@@ -104,10 +144,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
                     << "\n");
   llvm::SmallVector<Operation *, 8> blockingAccesses;
   Operation *firstOverwriteCandidate = nullptr;
-  Value source = write.getSource();
-  // Skip subview ops.
-  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
-    source = subView.getSource();
+  Value source = skipSubViewsAndCasts(write.getSource());
   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                            source.getUsers().end());
   llvm::SmallDenseSet<Operation *, 32> processed;
@@ -116,8 +153,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
     // If the user has already been processed skip.
     if (!processed.insert(user).second)
       continue;
-    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
-      users.append(subView->getUsers().begin(), subView->getUsers().end());
+    if (isa<memref::SubViewOp, memref::CastOp>(user)) {
+      users.append(user->getUsers().begin(), user->getUsers().end());
       continue;
     }
     if (isMemoryEffectFree(user))
@@ -126,7 +163,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
       continue;
     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
       // Check candidate that can override the store.
-      if (write.getSource() == nextWrite.getSource() &&
+      if (isSameViewOrTrivialAlias(nextWrite.getSource(), write.getSource()) &&
           checkSameValueWAW(nextWrite, write) &&
           postDominators.postDominates(nextWrite, write)) {
         if (firstOverwriteCandidate == nullptr ||
@@ -191,10 +228,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
                     << "\n");
   SmallVector<Operation *, 8> blockingWrites;
   vector::TransferWriteOp lastwrite = nullptr;
-  Value source = read.getSource();
-  // Skip subview ops.
-  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
-    source = subView.getSource();
+  Value source = skipSubViewsAndCasts(read.getSource());
   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                            source.getUsers().end());
   llvm::SmallDenseSet<Operation *, 32> processed;
@@ -203,12 +237,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
     // If the user has already been processed skip.
     if (!processed.insert(user).second)
       continue;
-    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
-      users.append(subView->getUsers().begin(), subView->getUsers().end());
-      continue;
-    }
-    if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
-      users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
+    if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
+      users.append(user->getUsers().begin(), user->getUsers().end());
       continue;
     }
     if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
@@ -221,7 +251,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
               cast<VectorTransferOpInterface>(read.getOperation()),
               /*testDynamicValueUsingBounds=*/true))
         continue;
-      if (write.getSource() == read.getSource() &&
+      if (isSameViewOrTrivialAlias(read.getSource(), write.getSource()) &&
           dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
         if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
           lastwrite = write;
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 13957af014b89e..e47d26940afa2f 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
 // `vector.transfer_write` would not be safe:
 //         %1 = vector.transfer_read %subview
 //         vector.transfer_write %1, %alloca
-//         vector.transfer_write %vec, %collapse_shape 
+//         vector.transfer_write %vec, %collapse_shape
 //         %2 = vector.transfer_read %alloca
 //         vector.transfer_write %1, %subview
 // Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,33 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
   vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
   return
 }
+
+// Here each read/write is to a different subview, but they all point to exact
+// same bit of memory (just through casts and subviews with unit strides and
+// zero offsets).
+// CHECK-LABEL: func @forward_and_eliminate_stores_through_trivial_aliases
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_and_eliminate_stores_through_trivial_aliases(
+  %buffer : memref<?x?xf32>, %vec: vector<[8]x[8]xf32>, %size: index, %a_size: index, %another_size: index
+) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c32 = arith.constant 32 : index
+  %cst = arith.constant 0.0 : f32
+  vector.transfer_write %vec, %buffer[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  %direct_subview = memref.subview %buffer[0, 0] [%a_size, %a_size] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+  %cast = memref.cast %direct_subview : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32>
+  %subview_of_cast = memref.subview %cast[0, 0] [%another_size, %another_size] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+  %21 = vector.transfer_read %direct_subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, 1], offset: ?>>, vector<[8]x[8]xf32>
+  %23 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %21) -> (vector<[8]x[8]xf32>) {
+    %24 = arith.addf %arg3, %arg3 : vector<[8]x[8]xf32>
+    scf.yield %24 : vector<[8]x[8]xf32>
+  }
+  vector.transfer_write %23, %subview_of_cast[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>
+  return
+}

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea looks good to me, but I wonder if running FoldMemRefAliasOps would help your case? I think they hopefully will all be folded into the same memref.subview ops.

https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! I think extending the transformation to deal with these simple alias cases make it more powerful. We know there are cases where some of these casts/subviews can't be folded.

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also LGTM cheers

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

MacDue added 3 commits May 16, 2024 08:57
This allows `TransferOptimization` to eliminate and forward stores that
are to trivial aliases (rather than just to identical memref values).

A trivial aliases is (currently) defined as:

  1. A `memref.cast`
  2. A `memref.subview` with a zero offset and unit strides
  3. A chain of 1 and 2
@MacDue MacDue force-pushed the more_transfer_opt branch from d3aecb2 to 1d70bac Compare May 16, 2024 09:05
@MacDue MacDue merged commit 90d2f8c into llvm:main May 16, 2024
3 of 4 checks passed
@MacDue MacDue deleted the more_transfer_opt branch May 16, 2024 09:53
@nicolasvasilache
Copy link
Contributor

I believe @matthias-springer has added related functionality a few weeks back?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants