Skip to content

[mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases #145235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

banach-space
Copy link
Contributor

This patch adds additional checks to the hoisting logic to prevent
hoisting of vector.transfer_read/vector.transfer_write pairs when
the underlying memref has users that introduce aliases via operations
implementing ViewLikeOpInterface.

Note: This may conservatively block some valid hoisting opportunities
and could impact performance. However, as demonstrated by the included
tests, the current behavior is too permissive and can lead to incorrect
transformations.

If this change prevents hoisting in cases that are provably safe, please
share a minimal repro — I’d be happy to explore ways to relax the check.

@llvmbot
Copy link
Member

llvmbot commented Jun 22, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

This patch adds additional checks to the hoisting logic to prevent
hoisting of vector.transfer_read/vector.transfer_write pairs when
the underlying memref has users that introduce aliases via operations
implementing ViewLikeOpInterface.

Note: This may conservatively block some valid hoisting opportunities
and could impact performance. However, as demonstrated by the included
tests, the current behavior is too permissive and can lead to incorrect
transformations.

If this change prevents hoisting in cases that are provably safe, please
share a minimal repro — I’d be happy to explore ways to relax the check.


Full diff: https://github.com/llvm/llvm-project/pull/145235.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp (+14-1)
  • (modified) mlir/test/Dialect/Linalg/hoisting.mlir (+133)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 707b63ff9335b..808925a934979 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -303,7 +303,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
       //   1. indices, vector type and permutation map are the same (i.e., the
       //      transfer_read/transfer_write ops are matching),
       //   2. source operands for transfer.{read|write} do not originate from
-      //      Ops implementing ViewLikeOpInterface.
+      //      nor have users that are Ops implementing ViewLikeOpInterface.
       //   3. no other operations in the loop access the same memref except
       //      for transfer_read/transfer_write accessing statically disjoint
       //      slices.
@@ -312,14 +312,27 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
           transferRead.getPermutationMap() != transferWrite.getPermutationMap())
         return WalkResult::advance();
 
+      // Check 2. for xfer_read
       auto *source = transferRead.getBase().getDefiningOp();
       if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
         return WalkResult::advance();
 
+      auto base = transferRead.getBase();
+      for (auto *user : base.getUsers())
+        if (isa_and_nonnull<ViewLikeOpInterface>(user))
+          return WalkResult::advance();
+
+      // Check 2. for xfer_wrire
       source = transferWrite.getBase().getDefiningOp();
       if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
         return WalkResult::advance();
 
+      base = transferWrite.getBase();
+      for (auto *user : base.getUsers())
+        if (isa_and_nonnull<ViewLikeOpInterface>(user))
+          return WalkResult::advance();
+
+      // Check 1. + 3.
       // TODO: may want to memoize this information for performance but it
       // likely gets invalidated often.
       DominanceInfo dom(loop);
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 318edca73cce1..fd5a3edfb743f 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,5 +1,138 @@
 // RUN: mlir-opt  -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
 
+// The most basic example - hoisting is safe.
+
+// CHECK-LABEL:   func.func @hoist_basic_vector_xfer_pair(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index) {
+func.func @hoist_basic_vector_xfer_pair(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:           %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK:             %[[VAL_6:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             scf.yield %[[VAL_6]] : vector<1xf32>
+// CHECK:           }
+// CHECK:           vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+  scf.for %i = %lb to %ub step %step {
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Similar as the example above, but hoisting is no longer safe. That's due to
+// an extra xfer_write inside the loop.
+
+// CHECK-LABEL:   func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:           }
+
+  scf.for %i = %lb to %ub step %step {
+      vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Similar as the example above, but hoisting is no longer safe. That's due to
+// an extra xfer_write into _an alias_ of the %mem Op that is used by the
+// original xfer pair.
+
+// CHECK-LABEL:   func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [1, 1] [1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:           }
+
+  %sv = memref.subview %mem[0, 0][1, 1][1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+  scf.for %i = %lb to %ub step %step {
+      vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @hoist_vector_transfer_pairs(
 //  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,

@llvmbot
Copy link
Member

llvmbot commented Jun 22, 2025

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

This patch adds additional checks to the hoisting logic to prevent
hoisting of vector.transfer_read/vector.transfer_write pairs when
the underlying memref has users that introduce aliases via operations
implementing ViewLikeOpInterface.

Note: This may conservatively block some valid hoisting opportunities
and could impact performance. However, as demonstrated by the included
tests, the current behavior is too permissive and can lead to incorrect
transformations.

If this change prevents hoisting in cases that are provably safe, please
share a minimal repro — I’d be happy to explore ways to relax the check.


Full diff: https://github.com/llvm/llvm-project/pull/145235.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp (+14-1)
  • (modified) mlir/test/Dialect/Linalg/hoisting.mlir (+133)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 707b63ff9335b..808925a934979 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -303,7 +303,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
       //   1. indices, vector type and permutation map are the same (i.e., the
       //      transfer_read/transfer_write ops are matching),
       //   2. source operands for transfer.{read|write} do not originate from
-      //      Ops implementing ViewLikeOpInterface.
+      //      nor have users that are Ops implementing ViewLikeOpInterface.
       //   3. no other operations in the loop access the same memref except
       //      for transfer_read/transfer_write accessing statically disjoint
       //      slices.
@@ -312,14 +312,27 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
           transferRead.getPermutationMap() != transferWrite.getPermutationMap())
         return WalkResult::advance();
 
+      // Check 2. for xfer_read
       auto *source = transferRead.getBase().getDefiningOp();
       if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
         return WalkResult::advance();
 
+      auto base = transferRead.getBase();
+      for (auto *user : base.getUsers())
+        if (isa_and_nonnull<ViewLikeOpInterface>(user))
+          return WalkResult::advance();
+
+      // Check 2. for xfer_wrire
       source = transferWrite.getBase().getDefiningOp();
       if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
         return WalkResult::advance();
 
+      base = transferWrite.getBase();
+      for (auto *user : base.getUsers())
+        if (isa_and_nonnull<ViewLikeOpInterface>(user))
+          return WalkResult::advance();
+
+      // Check 1. + 3.
       // TODO: may want to memoize this information for performance but it
       // likely gets invalidated often.
       DominanceInfo dom(loop);
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 318edca73cce1..fd5a3edfb743f 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,5 +1,138 @@
 // RUN: mlir-opt  -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
 
+// The most basic example - hoisting is safe.
+
+// CHECK-LABEL:   func.func @hoist_basic_vector_xfer_pair(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index) {
+func.func @hoist_basic_vector_xfer_pair(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:           %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK:             %[[VAL_6:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             scf.yield %[[VAL_6]] : vector<1xf32>
+// CHECK:           }
+// CHECK:           vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+  scf.for %i = %lb to %ub step %step {
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Similar as the example above, but hoisting is no longer safe. That's due to
+// an extra xfer_write inside the loop.
+
+// CHECK-LABEL:   func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:           }
+
+  scf.for %i = %lb to %ub step %step {
+      vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Similar as the example above, but hoisting is no longer safe. That's due to
+// an extra xfer_write into _an alias_ of the %mem Op that is used by the
+// original xfer pair.
+
+// CHECK-LABEL:   func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+// CHECK-SAME:      %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:      %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+    %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [1, 1] [1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+// CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK:             vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+// CHECK:             %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK:             %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK:             vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK:           }
+
+  %sv = memref.subview %mem[0, 0][1, 1][1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+  scf.for %i = %lb to %ub step %step {
+      vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+
+      %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_transfers %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @hoist_vector_transfer_pairs(
 //  CHECK-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  CHECK-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,

auto *source = transferRead.getBase().getDefiningOp();
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
return WalkResult::advance();

auto base = transferRead.getBase();
for (auto *user : base.getUsers())
if (isa_and_nonnull<ViewLikeOpInterface>(user))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems better to check the confilction at
Hoisting.cpp:Line349 than cancel the optimization.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using isDisjointTransferSetonly makes sense when the underlying base is identical, see:

if (transferA.getBase() != transferB.getBase())
return false;

Looking at the example from your repro, that would indeed apply to the pair that you want to hoist.

    %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
    scf.for %arg0 = %c256 to %c4096 step %c256 {
      %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
      %1 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
      %2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %1, %0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      vector.transfer_write %2, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
    }

However, it won't work if we modify the test a bit:

    %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
    scf.for %arg0 = %c256 to %c4096 step %c256 {
      %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
      %1 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
      %2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %1, %0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      vector.transfer_write %2, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
    }

In this case, the base for xfer_read and xfer_write Ops will be different (%assume_align and %alloc, respectively).

@banach-space banach-space force-pushed the andrzej/hoisting/make_aliasing_checks_stricter branch from 3db8207 to 36bbbdc Compare June 25, 2025 09:25
Copy link

github-actions bot commented Jun 25, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

// `memref.assume_alignment`)
if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
Value memPreAlignment = assume.getMemref();
if (base.hasNUses(2) && memPreAlignment.hasOneUse())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check the use num inside the loop, due to the base and memPreAlignment are usually used outside of the loop which should not affect the optimization.

@banach-space
Copy link
Contributor Author

NOTE: The 2nd commit implements:

I am happy to extract that into a separate PR if that makes more sense to you.

Copy link
Contributor

@xiangzh1 xiangzh1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for your effect.

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, it looks better!

Comment on lines 322 to 330
// NOTE: We treat `memref.assume_alignment` as a special case:
// 1. If it has exactly two uses then these have to be the xfer Ops
// being looked at.
// 2. Otherwise, there are other users that we should take into
// account
// In the case of 1., it is safe to look past AssumeAlignmentOp,
// i.e. at the defining Op of the input MemRef, provided that:
// * the original MemRef has only one use (i.e.
// `memref.assume_alignment`)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: format it a bit, e.g., alignment, missing period at the end.

Suggested change
// NOTE: We treat `memref.assume_alignment` as a special case:
// 1. If it has exactly two uses then these have to be the xfer Ops
// being looked at.
// 2. Otherwise, there are other users that we should take into
// account
// In the case of 1., it is safe to look past AssumeAlignmentOp,
// i.e. at the defining Op of the input MemRef, provided that:
// * the original MemRef has only one use (i.e.
// `memref.assume_alignment`)
// NOTE: We treat `memref.assume_alignment` as a special case:
// 1. If it has exactly two uses then these have to be the xfer Ops
// being looked at.
// 2. Otherwise, there are other users that we should take into
// account.
// In the case of 1., it is safe to look past AssumeAlignmentOp,
// i.e. at the defining Op of the input MemRef, provided that:
// * the original MemRef has only one use (i.e.
// `memref.assume_alignment`).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Comment on lines 345 to 347
for (auto *user : base.getUsers())
if (isa_and_nonnull<ViewLikeOpInterface>(user))
return WalkResult::advance();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional nit: I'd use any_of and llvm::IsaPred<> in this case. (I think the user won't be NULL, so dropping nonnull should be okay.) E.g.,

if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
  return WalkResult::advance();

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, that's miles better, thanks!

…aliases

This patch adds additional checks to the hoisting logic to prevent
hoisting of `vector.transfer_read`/`vector.transfer_write` pairs when
the underlying `memref` has users that introduce aliases via operations
implementing `ViewLikeOpInterface`.

Note: This may conservatively block some valid hoisting opportunities
and could impact performance. However, as demonstrated by the included
tests, the current behavior is too permissive and can lead to incorrect
transformations.

If this change prevents hoisting in cases that are provably safe, please
share a minimal repro — I’d be happy to explore ways to relax the check.
…nce of aliases

1. Relax the conditions in the case of `memref.assume_alignment`. This
   unblocks hoisting for examples like the one reported in
   llvm#144825 (see the link in
   the top post: "detailed example pls refer to example").
2. When checking the source operand for the xfer Ops, we only need to
   look at either xfer_write or xfer_read (we already know that the
   source is identical). The corresponding logic has been simplified.
…e presence of aliases

Make sure only uses inside the loop are counted
…s in the presence of aliases

Incorporate suggestions from HanHan, update the description.
…er pairs in the presence of aliases

Extra test
@banach-space banach-space force-pushed the andrzej/hoisting/make_aliasing_checks_stricter branch from 7f58588 to d4d95da Compare June 26, 2025 10:52
@banach-space
Copy link
Contributor Author

UPDATES (26/6/25):

If there are no new comments, I will land it before the end of the week. Thanks for reviews!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants