Skip to content

Commit 94c0477

Browse files
authored
[mlir][vector] Prevent incorrect vector.transfer_{read|write} hoisting (#66930)
At the moment, `hoistRedundantVectorTransfers` would hoist the `vector.transfer_read`/`vector.transfer_write` pair in this function: ```mlir func.func @no_hoisting_write_to_memref(%rhs: i32, %arg1: vector<1xi32>) { %c0_i32 = arith.constant 0 : i32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c20 = arith.constant 20 : index %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x1x2xi32> %cast = memref.cast %alloca : memref<1x1x2xi32> to memref<1x1x2xi32> %collapsed_1 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32> scf.for %_ = %c0 to %c20 step %c4 { %collapsed_2 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32> %lhs = vector.transfer_read %collapsed_1[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32> %acc = vector.transfer_read %collapsed_2[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32> %op = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<1xi32>, i32 vector.transfer_write %op, %collapsed_1[%c0] {in_bounds = [true]} : vector<1xi32>, memref<2xi32> } return } ``` as follows: ```mlir func.func @no_hoisting_write_to_memref(%arg0: i32, %arg1: vector<1xi32>) { %c0_i32 = arith.constant 0 : i32 %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %c20 = arith.constant 20 : index %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x1x2xi32> %collapse_shape = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32> %collapse_shape_0 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32> %0 = vector.transfer_read %collapse_shape[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32> %1 = vector.transfer_read %collapse_shape_0[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32> %2 = scf.for %arg2 = %c0 to %c20 step %c4 iter_args(%arg3 = %0) -> (vector<1xi32>) { %3 = vector.outerproduct %arg3, %arg0, %1 {kind = #vector.kind<add>} : vector<1xi32>, i32 scf.yield %3 : vector<1xi32> } vector.transfer_write %2, %collapse_shape[%c0] {in_bounds = [true]} : vector<1xi32>, memref<2xi32> return } ``` This is not safe. While one argument for `vector.outerproduct` (`%rhs` from the original loop) is correctly being forwarded via `iter_args`, the other one (`%acc` from the original loop) is not. This patch disables hoisting in cases where the source of "candidate" `vector.transfer_read` aliases with some other `memref`. A more generic approach would be to make sure that all values are correctly forwarded via `iter_args`, but that would require involving alias analysis. [1] Based on iree-org/iree#14994.
1 parent 5c9e90f commit 94c0477

File tree

3 files changed

+76
-16
lines changed

3 files changed

+76
-16
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,19 @@ namespace linalg {
2828
/// 3. No uses of the memref either dominate the transfer_read or are
2929
/// dominated by the transfer_write (i.e. no aliasing between the write and
3030
/// the read across the loop)
31+
/// 4. The source operands for vector.transfer_{read|write} do not originate
32+
/// from Ops implementing ViewLikeOpInterface (to reduce the risk of
33+
/// aliasing).
3134
/// To improve hoisting opportunities, call the `moveLoopInvariantCode` helper
3235
/// function on the candidate loop above which to hoist. Hoisting the transfers
3336
/// results in scf::ForOp yielding the value that originally transited through
3437
/// memory.
3538
///
39+
/// TODO: To further improve hoisting opportunities, fold aliasing memref
40+
/// operations into respective vector.transfer{read|write} operations and
41+
/// avoid using ops implementing ViewLikeOpInterface as the source for transfer
42+
/// Ops.
43+
///
3644
/// WARNING: This hoisting does not model parallelism and is generally incorrect
3745
/// when used on distributed loops with memref semantics!
3846
void hoistRedundantVectorTransfers(func::FuncOp func);

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,12 @@ void mlir::linalg::hoistRedundantVectorTransfersOnTensor(func::FuncOp func) {
5555
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
5656
LoopLikeOpInterface loop) {
5757
Value source = transferRead.getSource();
58-
// Skip subview and collapse_shape Ops
59-
while (auto subView = source.getDefiningOp<memref::SubViewOp>())
60-
source = subView.getSource();
61-
while (auto collapsed = source.getDefiningOp<memref::CollapseShapeOp>())
62-
source = collapsed->getOperand(0);
58+
59+
// Skip view-like Ops and retrive the actual soruce Operation
60+
while (auto srcOp =
61+
dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp()))
62+
source = srcOp.getViewSource();
63+
6364
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
6465
source.getUsers().end());
6566
llvm::SmallDenseSet<Operation *, 32> processed;
@@ -68,12 +69,8 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
6869
// If the user has already been processed skip.
6970
if (!processed.insert(user).second)
7071
continue;
71-
if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
72-
users.append(subView->getUsers().begin(), subView->getUsers().end());
73-
continue;
74-
}
75-
if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
76-
users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
72+
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) {
73+
users.append(viewLike->getUsers().begin(), viewLike->getUsers().end());
7774
continue;
7875
}
7976
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
@@ -144,14 +141,24 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
144141
// Approximate aliasing by checking that:
145142
// 1. indices, vector type and permutation map are the same (i.e., the
146143
// transfer_read/transfer_write ops are matching),
147-
// 2. no other operations in the loop access the same memref except
144+
// 2. source operands for transfer.{read|write} do not originate from
145+
// Ops implementing ViewLikeOpInterface.
146+
// 3. no other operations in the loop access the same memref except
148147
// for transfer_read/transfer_write accessing statically disjoint
149148
// slices.
150149
if (transferRead.getIndices() != transferWrite.getIndices() ||
151150
transferRead.getVectorType() != transferWrite.getVectorType() ||
152151
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
153152
return WalkResult::advance();
154153

154+
auto *source = transferRead.getSource().getDefiningOp();
155+
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
156+
return WalkResult::advance();
157+
158+
source = transferWrite.getSource().getDefiningOp();
159+
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
160+
return WalkResult::advance();
161+
155162
// TODO: may want to memoize this information for performance but it
156163
// likely gets invalidated often.
157164
DominanceInfo dom(loop);

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -765,10 +765,10 @@ transform.sequence failures(propagate) {
765765

766766
// CHECK-LABEL: func.func @no_hoisting_collapse_shape
767767
// CHECK: scf.for {{.*}} {
768-
// CHECK: vector.transfer_write
769-
// CHECK: vector.transfer_read
770-
// CHECK: vector.transfer_write
771-
// CHECK: }
768+
// CHECK: vector.transfer_write {{.*}} : vector<4xi32>, memref<4xi32>
769+
// CHECK-NEXT: vector.transfer_read {{.*}} : memref<1x4x1xi32>, vector<1x4x1xi32>
770+
// CHECK-NEXT: vector.transfer_write {{.*}} : vector<1x4x1xi32>, memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>
771+
// CHECK-NEXT: }
772772

773773
func.func @no_hoisting_collapse_shape(%in_0: memref<1x20x1xi32>, %1: memref<9x1xi32>, %vec: vector<4xi32>) {
774774
%c0_i32 = arith.constant 0 : i32
@@ -827,3 +827,48 @@ transform.sequence failures(propagate) {
827827
transform.structured.hoist_redundant_vector_transfers %0
828828
: (!transform.any_op) -> !transform.any_op
829829
}
830+
831+
// -----
832+
833+
// Regression test - hoisting the following `vector.transfer_{read|write}` pair
834+
// would not be safe:
835+
// %lhs = vector.transfer_read %collapsed_1[%c0]
836+
// vector.transfer_write %op, %collapsed_1[%c0]
837+
// That's because the following `vector.transfer_read` reads from the same
838+
// memory (i.e. `%collapsed_1` and `%collapsed_2` alias):
839+
// %acc = vector.transfer_read %collapsed_2[%c0]
840+
841+
// CHECK-LABEL: func.func @no_hoisting_write_to_memref
842+
// CHECK: scf.for {{.*}} {
843+
// CHECK: vector.transfer_read {{.*}} : memref<2xi32>, vector<1xi32>
844+
// CHECK-NEXT: vector.transfer_read {{.*}} : memref<2xi32>, vector<1xi32>
845+
// CHECK-NEXT: vector.outerproduct {{.*}} : vector<1xi32>, i32
846+
// CHECK-NEXT: vector.transfer_write {{.*}} : vector<1xi32>, memref<2xi32>
847+
// CHECK-NEXT: }
848+
849+
func.func @no_hoisting_write_to_memref(%rhs: i32, %arg1: vector<1xi32>) {
850+
%c0_i32 = arith.constant 0 : i32
851+
%c0 = arith.constant 0 : index
852+
%c1 = arith.constant 1 : index
853+
%c4 = arith.constant 4 : index
854+
%c20 = arith.constant 20 : index
855+
%alloca = memref.alloca() {alignment = 64 : i64} : memref<1x1x2xi32>
856+
%cast = memref.cast %alloca : memref<1x1x2xi32> to memref<1x1x2xi32>
857+
%collapsed_1 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32>
858+
scf.for %_ = %c0 to %c20 step %c4 {
859+
%collapsed_2 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32>
860+
%lhs = vector.transfer_read %collapsed_1[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32>
861+
%acc = vector.transfer_read %collapsed_2[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32>
862+
%op = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<1xi32>, i32
863+
vector.transfer_write %op, %collapsed_1[%c0] {in_bounds = [true]} : vector<1xi32>, memref<2xi32>
864+
}
865+
return
866+
}
867+
868+
transform.sequence failures(propagate) {
869+
^bb1(%arg1: !transform.any_op):
870+
%0 = transform.structured.match ops{["func.func"]} in %arg1
871+
: (!transform.any_op) -> !transform.any_op
872+
transform.structured.hoist_redundant_vector_transfers %0
873+
: (!transform.any_op) -> !transform.any_op
874+
}

0 commit comments

Comments
 (0)