Skip to content

Commit 3db8207

Browse files
committed
fixup! [mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases
1. Relax the conditions in the case of `memref.assume_alignment`. This unblocks hoisting for examples like the one reported in #144825 (see the link in the top post: "detailed example pls refer to example"). 2. When checking the source operand for the xfer Ops, we only need to look at either xfer_write or xfer_read (we already know that the source is identical). The corresponding logic has been simplified.
1 parent f7218d7 commit 3db8207

File tree

2 files changed

+70
-15
lines changed

2 files changed

+70
-15
lines changed

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -307,32 +307,41 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
307307
// 3. no other operations in the loop access the same memref except
308308
// for transfer_read/transfer_write accessing statically disjoint
309309
// slices.
310+
311+
// Check 1.
310312
if (transferRead.getIndices() != transferWrite.getIndices() ||
311313
transferRead.getVectorType() != transferWrite.getVectorType() ||
312314
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
313315
return WalkResult::advance();
314316

315-
// Check 2. for xfer_read
316-
auto *source = transferRead.getBase().getDefiningOp();
317-
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
318-
return WalkResult::advance();
319-
317+
// Check 2. Note, since both xfer Ops share the source, we only need to look at
318+
// one of them.
320319
auto base = transferRead.getBase();
321-
for (auto *user : base.getUsers())
322-
if (isa_and_nonnull<ViewLikeOpInterface>(user))
320+
auto *source = base.getDefiningOp();
321+
if (source) {
322+
// NOTE: We treat `memref.assume_alignment` as a special case:
323+
// 1. If it has exactly two uses then these have to be the xfer Ops
324+
// being looked at.
325+
// 2. Otherwise, there are other users that we should take into
326+
// account.
327+
// In the case of 1., it is safe to look past AssumeAlignmentOp,
328+
// i.e. at the defining Op of the input MemRef, provided that:
329+
// * the original MemRef has only one use (i.e.
330+
// `memref.assume_alignment`)
331+
if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
332+
Value memPreAlignment = assume.getMemref();
333+
if (base.hasNUses(2) && memPreAlignment.hasOneUse())
334+
source = memPreAlignment.getDefiningOp();
335+
}
336+
if (isa_and_nonnull<ViewLikeOpInterface>(source))
323337
return WalkResult::advance();
338+
}
324339

325-
// Check 2. for xfer_wrire
326-
source = transferWrite.getBase().getDefiningOp();
327-
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
328-
return WalkResult::advance();
329-
330-
base = transferWrite.getBase();
331340
for (auto *user : base.getUsers())
332341
if (isa_and_nonnull<ViewLikeOpInterface>(user))
333342
return WalkResult::advance();
334343

335-
// Check 1. + 3.
344+
// Check 3.
336345
// TODO: may want to memoize this information for performance but it
337346
// likely gets invalidated often.
338347
DominanceInfo dom(loop);
@@ -371,7 +380,8 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
371380
// Hoist write after.
372381
transferWrite->moveAfter(loop);
373382

374-
// Rewrite `loop` with new yields by cloning and erase the original loop.
383+
// Rewrite `loop` with new yields by cloning and erase the original
384+
// loop.
375385
IRRewriter rewriter(transferRead.getContext());
376386
NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
377387
ArrayRef<BlockArgument> newBBArgs) {

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,51 @@ module attributes {transform.with_named_sequence} {
133133

134134
// -----
135135

136+
// Similar to the example above, but the memory access is done via
137+
// memref.assume_alignment. Hoisting is safe as the only users of the
138+
// "alignment" Op are the xfer Ops within the loop that we want to hoist.
139+
140+
// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair_with_assume_align(
141+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
142+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
143+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
144+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
145+
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
146+
func.func @hoist_basic_vector_xfer_pair_with_assume_align(
147+
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
148+
%c0 = arith.constant 0 : index
149+
%pad = arith.constant 0.0 : f32
150+
151+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
152+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
153+
// CHECK: %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
154+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
155+
// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
156+
// CHECK: %[[USE:.*]] = "some_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
157+
// CHECK: }
158+
// CHECK: vector.transfer_write %[[SCF]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
159+
160+
%aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
161+
scf.for %i = %lb to %ub step %step {
162+
%r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
163+
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
164+
vector.transfer_write %u0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
165+
}
166+
return
167+
}
168+
169+
module attributes {transform.with_named_sequence} {
170+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
171+
%0 = transform.structured.match ops{["func.func"]} in %arg1
172+
: (!transform.any_op) -> !transform.any_op
173+
transform.structured.hoist_redundant_vector_transfers %0
174+
: (!transform.any_op) -> !transform.any_op
175+
transform.yield
176+
}
177+
}
178+
179+
// -----
180+
136181
// CHECK-LABEL: func @hoist_vector_transfer_pairs(
137182
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
138183
// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,

0 commit comments

Comments
 (0)