Skip to content

Commit 7f58588

Browse files
committed
fixup! fixup! fixup! [mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases
Incorporate suggestions from HanHan, update the description.
1 parent a671f3d commit 7f58588

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -319,15 +319,17 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
319319
auto base = transferRead.getBase();
320320
auto *source = base.getDefiningOp();
321321
if (source) {
322-
// NOTE: We treat `memref.assume_alignment` as a special case:
323-
// 1. If it has exactly two uses then these have to be the xfer Ops
324-
// being looked at.
325-
// 2. Otherwise, there are other users that we should take into
326-
// account
327-
// In the case of 1., it is safe to look past AssumeAlignmentOp,
328-
// i.e. at the defining Op of the input MemRef, provided that:
329-
// * the original MemRef has only one use (i.e.
330-
// `memref.assume_alignment`)
322+
// NOTE: We treat `memref.assume_alignment` as a special case.
323+
//
324+
// The idea is that it is safe to look past AssumeAlignmemtOp (i.e.
325+
// MemRef _before_ alignment) iff:
326+
// 1. It has exactly two uses (these have to be the xfer Ops
327+
// being looked at).
328+
// 2. The original MemRef has only one use (i.e.
329+
// AssumeAlignmentOp).
330+
//
331+
// Relaxing these conditions will most likely require proper alias
332+
// analysis.
331333
if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
332334
Value memPreAlignment = assume.getMemref();
333335
auto numInLoopUses =
@@ -342,9 +344,8 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
342344
return WalkResult::advance();
343345
}
344346

345-
for (auto *user : base.getUsers())
346-
if (isa_and_nonnull<ViewLikeOpInterface>(user))
347-
return WalkResult::advance();
347+
if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
348+
return WalkResult::advance();
348349

349350
// Check 3.
350351
// TODO: may want to memoize this information for performance but it

0 commit comments

Comments
 (0)