[mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases #145235

Conversation
@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

This patch adds additional checks to the hoisting logic to prevent hoisting of vector.transfer_read/vector.transfer_write pairs when the underlying memref has users that introduce aliases via operations implementing ViewLikeOpInterface.

Note: This may conservatively block some valid hoisting opportunities and could impact performance. However, as demonstrated by the included tests, the current behavior is too permissive and can lead to incorrect transformations.

If this change prevents hoisting in cases that are provably safe, please share a minimal repro; I'd be happy to explore ways to relax the check.

Full diff: https://github.com/llvm/llvm-project/pull/145235.diff

2 Files Affected:
- mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
- mlir/test/Dialect/Linalg/hoisting.mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 707b63ff9335b..808925a934979 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -303,7 +303,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
// 1. indices, vector type and permutation map are the same (i.e., the
// transfer_read/transfer_write ops are matching),
// 2. source operands for transfer.{read|write} do not originate from
- // Ops implementing ViewLikeOpInterface.
+ // nor have users that are Ops implementing ViewLikeOpInterface.
// 3. no other operations in the loop access the same memref except
// for transfer_read/transfer_write accessing statically disjoint
// slices.
@@ -312,14 +312,27 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
return WalkResult::advance();
+ // Check 2. for xfer_read
auto *source = transferRead.getBase().getDefiningOp();
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
return WalkResult::advance();
+ auto base = transferRead.getBase();
+ for (auto *user : base.getUsers())
+ if (isa_and_nonnull<ViewLikeOpInterface>(user))
+ return WalkResult::advance();
+
+ // Check 2. for xfer_write
source = transferWrite.getBase().getDefiningOp();
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
return WalkResult::advance();
+ base = transferWrite.getBase();
+ for (auto *user : base.getUsers())
+ if (isa_and_nonnull<ViewLikeOpInterface>(user))
+ return WalkResult::advance();
+
+ // Check 1. + 3.
// TODO: may want to memoize this information for performance but it
// likely gets invalidated often.
DominanceInfo dom(loop);
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 318edca73cce1..fd5a3edfb743f 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,5 +1,138 @@
// RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
+// The most basic example - hoisting is safe.
+
+// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair(
+// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) {
+func.func @hoist_basic_vector_xfer_pair(
+ %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
+// CHECK: %[[VAL_6:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK: scf.yield %[[VAL_6]] : vector<1xf32>
+// CHECK: }
+// CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+ scf.for %i = %lb to %ub step %step {
+ %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+ %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Similar to the example above, but hoisting is no longer safe. That's due to
+// an extra xfer_write inside the loop.
+
+// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
+ %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK: }
+
+ scf.for %i = %lb to %ub step %step {
+ vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+
+ %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+ %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Similar to the example above, but hoisting is no longer safe. That's due to
+// an extra xfer_write into _an alias_ of the %mem Op that is used by the
+// original xfer pair.
+
+// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
+func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
+ %mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [1, 1] [1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
+// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
+// CHECK: }
+
+ %sv = memref.subview %mem[0, 0][1, 1][1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
+ scf.for %i = %lb to %ub step %step {
+ vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
+
+ %r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
+ %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_transfers %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: func @hoist_vector_transfer_pairs(
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
auto *source = transferRead.getBase().getDefiningOp();
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
  return WalkResult::advance();

auto base = transferRead.getBase();
for (auto *user : base.getUsers())
  if (isa_and_nonnull<ViewLikeOpInterface>(user))
It seems better to check for the conflict at Hoisting.cpp, line 349, than to cancel the optimization.
Using isDisjointTransferSet only makes sense when the underlying base is identical, see llvm-project/mlir/lib/Dialect/Vector/IR/VectorOps.cpp, lines 318 to 319 (at 577199f):

if (transferA.getBase() != transferB.getBase())
  return false;
Looking at the example from your repro, that would indeed apply to the pair that you want to hoist.
%assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
scf.for %arg0 = %c256 to %c4096 step %c256 {
%0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
%1 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
%2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %1, %0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
vector.transfer_write %2, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
}
However, it won't work if we modify the test a bit:
%assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
scf.for %arg0 = %c256 to %c4096 step %c256 {
%0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
%1 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
%2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %1, %0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
vector.transfer_write %2, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
}
In this case, the base for the xfer_read and xfer_write Ops will be different (%assume_align and %alloc, respectively).
✅ With the latest revision this PR passed the C/C++ code formatter.
// `memref.assume_alignment`)
if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
  Value memPreAlignment = assume.getMemref();
  if (base.hasNUses(2) && memPreAlignment.hasOneUse())
We should count only the uses inside the loop: base and memPreAlignment are usually also used outside of the loop, and those uses should not affect the optimization.
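A minimal sketch of that suggestion, assuming loop is the enclosing loop operation (the helper name and exact shape are illustrative, not the PR's code):

// Count only the uses of `v` whose owning op is nested inside the
// loop; uses outside the loop cannot interfere with hoisting the pair.
auto countUsesInsideLoop = [](Value v, Operation *loopOp) {
  return llvm::count_if(v.getUses(), [&](OpOperand &use) {
    return loopOp->isAncestor(use.getOwner());
  });
};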
NOTE: The 2nd commit implements this. I am happy to extract that into a separate PR if that makes more sense to you.
LGTM, thanks for your effort.
Thanks, it looks better!
// NOTE: We treat `memref.assume_alignment` as a special case:
// 1. If it has exactly two uses then these have to be the xfer Ops
//    being looked at.
// 2. Otherwise, there are other users that we should take into
//    account
// In the case of 1., it is safe to look past AssumeAlignmentOp,
// i.e. at the defining Op of the input MemRef, provided that:
// * the original MemRef has only one use (i.e.
//   `memref.assume_alignment`)
nit: format it a bit, e.g., fix the alignment and the missing periods at the end.
Suggested change:

// NOTE: We treat `memref.assume_alignment` as a special case:
// 1. If it has exactly two uses then these have to be the xfer Ops
//    being looked at.
// 2. Otherwise, there are other users that we should take into
//    account.
// In the case of 1., it is safe to look past AssumeAlignmentOp,
// i.e. at the defining Op of the input MemRef, provided that:
// * the original MemRef has only one use (i.e.
//   `memref.assume_alignment`).
Thanks!
for (auto *user : base.getUsers())
  if (isa_and_nonnull<ViewLikeOpInterface>(user))
    return WalkResult::advance();
optional nit: I'd use any_of and llvm::IsaPred<> in this case. (I think the user won't be NULL, so dropping nonnull should be okay.) E.g.,

if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
  return WalkResult::advance();
Oh yes, that's miles better, thanks!
…aliases
This patch adds additional checks to the hoisting logic to prevent hoisting of `vector.transfer_read`/`vector.transfer_write` pairs when the underlying `memref` has users that introduce aliases via operations implementing `ViewLikeOpInterface`. Note: This may conservatively block some valid hoisting opportunities and could impact performance. However, as demonstrated by the included tests, the current behavior is too permissive and can lead to incorrect transformations. If this change prevents hoisting in cases that are provably safe, please share a minimal repro; I'd be happy to explore ways to relax the check.

…nce of aliases
1. Relax the conditions in the case of `memref.assume_alignment`. This unblocks hoisting for examples like the one reported in llvm#144825 (see the link in the top post: "detailed example pls refer to example").
2. When checking the source operand for the xfer Ops, we only need to look at either xfer_write or xfer_read (we already know that the source is identical). The corresponding logic has been simplified.

…e presence of aliases
Make sure only uses inside the loop are counted.

…s in the presence of aliases
Incorporate suggestions from HanHan, update the description.

…er pairs in the presence of aliases
Extra test.
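For reference, a hedged sketch of the simplification described in the second commit above: once the matching step has established that the read and the write share the same base, a single alias check covers both sides (approximate shape, not the exact PR code):

// Matching transfers were already verified to share the same base,
// so one alias check covers both the read and the write side.
Value base = transferRead.getBase();
if (isa_and_nonnull<ViewLikeOpInterface>(base.getDefiningOp()))
  return WalkResult::advance();
if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
  return WalkResult::advance();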
UPDATES (26/6/25): If there are no new comments, I will land it before the end of the week. Thanks for the reviews!