Skip to content

[mlir][hoisting] Support memref.assume_alignment in linalg hoisting #144843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

xiangzh1
Copy link
Contributor

The recent updates of AssumeAlignmentOp will affect linalg hoisting optimization.
We find it has regression on "hoist load/store out of loop".
The flowing issue list more detail:

related issue : 144825

This patch tend to fix this problem due to the assume_alignment just mark memref's alignment,
the linalg hoisting should check its memref operand not it self.

xiangzh1 added 2 commits June 18, 2025 17:10
All ViewLike operations are excluded by hoisting optimization. But
assume_alignment just mark memref's alignment, we should check its
memref instead of itself.
@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: XiangZhang (xiangzh1)

Changes

The recent updates of AssumeAlignmentOp will affect linalg hoisting optimization.
We find it has regression on "hoist load/store out of loop".
The flowing issue list more detail:

related issue : 144825

This patch tend to fix this problem due to the assume_alignment just mark memref's alignment,
the linalg hoisting should check its memref operand not it self.


Full diff: https://github.com/llvm/llvm-project/pull/144843.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp (+21-5)
  • (modified) mlir/test/Dialect/Linalg/hoisting.mlir (+52)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 707b63ff9335b..b949b06631484 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -199,6 +199,24 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
   return true;
 }
 
+static bool skipViewLike(Operation *source0, Operation *source1) {
+  bool viewLikeCheck = true;
+  auto assumeAlignOp = dyn_cast_or_null<memref::AssumeAlignmentOp>(source0);
+  if (assumeAlignOp && source0 == source1) {
+    Value sourceMemRef = assumeAlignOp.getMemref();
+    Operation *sourceOp = sourceMemRef.getDefiningOp();
+    return isa_and_nonnull<ViewLikeOpInterface>(sourceOp);
+  }
+
+  if (source0 && isa_and_nonnull<ViewLikeOpInterface>(source0))
+    return true;
+
+  if (source1 && isa_and_nonnull<ViewLikeOpInterface>(source1))
+    return true;
+
+  return false;
+}
+
 void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
                                                  bool verifyNonZeroTrip) {
   bool changed = true;
@@ -312,12 +330,10 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
           transferRead.getPermutationMap() != transferWrite.getPermutationMap())
         return WalkResult::advance();
 
-      auto *source = transferRead.getBase().getDefiningOp();
-      if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
-        return WalkResult::advance();
+      auto *source0 = transferRead.getBase().getDefiningOp();
+      auto *source1 = transferWrite.getBase().getDefiningOp();
 
-      source = transferWrite.getBase().getDefiningOp();
-      if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
+      if (skipViewLike(source0, source1))
         return WalkResult::advance();
 
       // TODO: may want to memoize this information for performance but it
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 318edca73cce1..c58074e40c5f4 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -802,3 +802,55 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Test hoisting of vector.transfer_read/transfer_write pairs with same location
+// and this location is marked with assume_align.
+
+// CHECK-LABEL:  func.func @hoist_vector_transfer_read_write() {
+// CHECK:          %c0 = arith.constant 0 : index
+// CHECK-NEXT:     %c256 = arith.constant 256 : index
+// CHECK-NEXT:     %c4096 = arith.constant 4096 : index
+// CHECK-NEXT:     %cst = arith.constant 0.000000e+00 : f16
+// CHECK-NEXT:     %alloc = memref.alloc() : memref<4096x4096xf16>
+// CHECK-NEXT:     %alloc_0 = memref.alloc() : memref<4096x4096xf16>
+// CHECK-NEXT:     %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
+// CHECK-NEXT:     %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT:     %1 = scf.for %arg0 = %c256 to %c4096 step %c256 iter_args(%arg1 = %0) -> (vector<16x16xf16>) {
+// CHECK-NEXT:       %2 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+// CHECK-NEXT:       %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %arg1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+// CHECK-NEXT:       scf.yield %3 : vector<16x16xf16>
+// CHECK-NEXT:     }
+// CHECK-NEXT:     vector.transfer_write %1, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
+// CHECK-NEXT:     return
+// CHECK-NEXT:   }
+
+func.func @hoist_vector_transfer_read_write() {
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c256 = arith.constant 256 : index
+  %c4096 = arith.constant 4096 : index
+  %cst_0 = arith.constant 0.000000e+00 : f16
+  %m0 = memref.alloc() : memref<4096x4096xf16>
+  %m1 = memref.alloc() : memref<4096x4096xf16>
+  %assume_align_0 = memref.assume_alignment %m0, 64 : memref<4096x4096xf16>
+  %assume_align_1 = memref.assume_alignment %m1, 64 : memref<4096x4096xf16>
+  scf.for %arg0 = %c256 to %c4096 step %c256 {
+    %1 = vector.transfer_read %assume_align_0[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+    %2 = vector.transfer_read %m1[%arg0, %arg0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
+    %3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+    vector.transfer_write %3, %assume_align_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+        transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

@xiangzh1
Copy link
Contributor Author

ping, thanks!

@banach-space
Copy link
Contributor

banach-space commented Jun 22, 2025

Please see the related discussion here: #144809

I will post more comments soon.

@xiangzh1
Copy link
Contributor Author

xiangzh1 commented Jun 23, 2025

Please see the related discussion here: #144809

I will post more comments soon.

Yes,they are same problem, let's discuss in #144809.
By the way,do you know which tool can auto-gen the test CHECK in mlir/test/Dialect/Linalg/hoisting.mlir ?

@banach-space
Copy link
Contributor

By the way,do you know which tool can auto-gen the test CHECK in mlir/test/Dialect/Linalg/hoisting.mlir ?

Have you tried generate-test-checks.py?

@xiangzh1
Copy link
Contributor Author

By the way,do you know which tool can auto-gen the test CHECK in mlir/test/Dialect/Linalg/hoisting.mlir ?

Have you tried generate-test-checks.py?

not before, never use it ^_^
~/LLVM/llvm-project$ ./mlir/utils/generate-test-checks.py mlir/test/Dialect/Linalg/hoisting.mlir
it works now,( just not like update_llc_test_checks.py , generate-test-checks.py didn't change the test itself, it just print to standard output。)
thanks a lot !

@xiangzh1
Copy link
Contributor Author

After discuss in 144809 ,we prefer banach-space‘s 145235 to fix it.

@xiangzh1 xiangzh1 closed this Jun 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
good first issue https://github.com/llvm/llvm-project/contribute mlir:linalg mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants