Skip to content

Commit f7218d7

Browse files
committed
[mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases
This patch adds additional checks to the hoisting logic to prevent hoisting of `vector.transfer_read`/`vector.transfer_write` pairs when the underlying `memref` has users that introduce aliases via operations implementing `ViewLikeOpInterface`. Note: This may conservatively block some valid hoisting opportunities and could impact performance. However, as demonstrated by the included tests, the current behavior is too permissive and can lead to incorrect transformations. If this change prevents hoisting in cases that are provably safe, please share a minimal repro — I’d be happy to explore ways to relax the check.
1 parent 577199f commit f7218d7

File tree

2 files changed

+147
-1
lines changed

2 files changed

+147
-1
lines changed

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
303303
// 1. indices, vector type and permutation map are the same (i.e., the
304304
// transfer_read/transfer_write ops are matching),
305305
// 2. source operands for transfer.{read|write} do not originate from
306-
// Ops implementing ViewLikeOpInterface.
306+
// nor have users that are Ops implementing ViewLikeOpInterface.
307307
// 3. no other operations in the loop access the same memref except
308308
// for transfer_read/transfer_write accessing statically disjoint
309309
// slices.
@@ -312,14 +312,27 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
312312
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
313313
return WalkResult::advance();
314314

315+
// Check 2. for xfer_read
315316
auto *source = transferRead.getBase().getDefiningOp();
316317
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
317318
return WalkResult::advance();
318319

320+
auto base = transferRead.getBase();
321+
for (auto *user : base.getUsers())
322+
if (isa_and_nonnull<ViewLikeOpInterface>(user))
323+
return WalkResult::advance();
324+
325+
// Check 2. for xfer_wrire
319326
source = transferWrite.getBase().getDefiningOp();
320327
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
321328
return WalkResult::advance();
322329

330+
base = transferWrite.getBase();
331+
for (auto *user : base.getUsers())
332+
if (isa_and_nonnull<ViewLikeOpInterface>(user))
333+
return WalkResult::advance();
334+
335+
// Check 1. + 3.
323336
// TODO: may want to memoize this information for performance but it
324337
// likely gets invalidated often.
325338
DominanceInfo dom(loop);

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,138 @@
11
// RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
22

3+
// The most basic example - hoisting is safe.
4+
5+
// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair(
6+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
7+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
8+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
9+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) {
10+
func.func @hoist_basic_vector_xfer_pair(
11+
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
12+
%c0 = arith.constant 0 : index
13+
%pad = arith.constant 0.0 : f32
14+
15+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
16+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
17+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
18+
// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
19+
// CHECK: %[[VAL_6:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
20+
// CHECK: scf.yield %[[VAL_6]] : vector<1xf32>
21+
// CHECK: }
22+
// CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
23+
scf.for %i = %lb to %ub step %step {
24+
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
25+
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
26+
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
27+
}
28+
return
29+
}
30+
31+
module attributes {transform.with_named_sequence} {
32+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
33+
%0 = transform.structured.match ops{["func.func"]} in %arg1
34+
: (!transform.any_op) -> !transform.any_op
35+
transform.structured.hoist_redundant_vector_transfers %0
36+
: (!transform.any_op) -> !transform.any_op
37+
transform.yield
38+
}
39+
}
40+
41+
// -----
42+
43+
// Similar as the example above, but hoisting is no longer safe. That's due to
44+
// an extra xfer_write inside the loop.
45+
46+
// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
47+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
48+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
49+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
50+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
51+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
52+
func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
53+
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
54+
%c0 = arith.constant 0 : index
55+
%pad = arith.constant 0.0 : f32
56+
57+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
58+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
59+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
60+
// CHECK: vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
61+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
62+
// CHECK: %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
63+
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
64+
// CHECK: }
65+
66+
scf.for %i = %lb to %ub step %step {
67+
vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
68+
69+
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
70+
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
71+
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
72+
}
73+
return
74+
}
75+
76+
module attributes {transform.with_named_sequence} {
77+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
78+
%0 = transform.structured.match ops{["func.func"]} in %arg1
79+
: (!transform.any_op) -> !transform.any_op
80+
transform.structured.hoist_redundant_vector_transfers %0
81+
: (!transform.any_op) -> !transform.any_op
82+
transform.yield
83+
}
84+
}
85+
86+
// -----
87+
88+
// Similar as the example above, but hoisting is no longer safe. That's due to
89+
// an extra xfer_write into _an alias_ of the %mem Op that is used by the
90+
// original xfer pair.
91+
92+
// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
93+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
94+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
95+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
96+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
97+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
98+
func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
99+
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
100+
%c0 = arith.constant 0 : index
101+
%pad = arith.constant 0.0 : f32
102+
103+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
104+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
105+
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [1, 1] [1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
106+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
107+
// CHECK: vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
108+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
109+
// CHECK: %[[USE:.*]] = "some_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
110+
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
111+
// CHECK: }
112+
113+
%sv = memref.subview %mem[0, 0][1, 1][1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
114+
scf.for %i = %lb to %ub step %step {
115+
vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
116+
117+
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
118+
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
119+
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
120+
}
121+
return
122+
}
123+
124+
module attributes {transform.with_named_sequence} {
125+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
126+
%0 = transform.structured.match ops{["func.func"]} in %arg1
127+
: (!transform.any_op) -> !transform.any_op
128+
transform.structured.hoist_redundant_vector_transfers %0
129+
: (!transform.any_op) -> !transform.any_op
130+
transform.yield
131+
}
132+
}
133+
134+
// -----
135+
3136
// CHECK-LABEL: func @hoist_vector_transfer_pairs(
4137
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
5138
// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,

0 commit comments

Comments
 (0)