[mlir][linalg] Prevent hoisting of transfer pairs in the presence of aliases #145235

Open · wants to merge 5 commits into base: main
43 changes: 36 additions & 7 deletions mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -303,23 +303,51 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
// 1. indices, vector type and permutation map are the same (i.e., the
// transfer_read/transfer_write ops are matching),
// 2. source operands for transfer.{read|write} do not originate from
// Ops implementing ViewLikeOpInterface,
// nor have users that are Ops implementing ViewLikeOpInterface.
// 3. no other operations in the loop access the same memref except
// for transfer_read/transfer_write accessing statically disjoint
// slices.
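// As a hedged illustration of the carve-out in 3. (a sketch, not from
// this patch): with constant indices and vector<1xf32>, an extra
//   vector.transfer_write %v, %mem[%c4, %c4]
// inside the loop touches only element (4, 4), which is statically
// disjoint from a hoisted pair accessing %mem[%c0, %c0], so it does not
// block hoisting.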

// Check 1.
if (transferRead.getIndices() != transferWrite.getIndices() ||
transferRead.getVectorType() != transferWrite.getVectorType() ||
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
return WalkResult::advance();

auto *source = transferRead.getBase().getDefiningOp();
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
return WalkResult::advance();
// Check 2. Note that, since both xfer Ops share the source, we only need to
// look at one of them.
auto base = transferRead.getBase();
auto *source = base.getDefiningOp();
if (source) {
// NOTE: We treat `memref.assume_alignment` as a special case.
//
// The idea is that it is safe to look past AssumeAlignmentOp (i.e. to
// use the MemRef _before_ alignment) iff:
// 1. It has exactly two uses (these have to be the xfer Ops
// being looked at).
// 2. The original MemRef has only one use (i.e.
// AssumeAlignmentOp).
//
// Relaxing these conditions will most likely require proper alias
// analysis.
if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
Value memPreAlignment = assume.getMemref();
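// Count how many uses of the aligned memref are nested inside the loop.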
auto numInLoopUses =
llvm::count_if(base.getUses(), [&loop](OpOperand &use) {
return loop->isAncestor(use.getOwner());
});

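// Only look past the alignment Op if its result is used inside the loop
// and the pre-alignment memref has no user other than the alignment Op.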
if (numInLoopUses && memPreAlignment.hasOneUse())
source = memPreAlignment.getDefiningOp();
}
if (isa_and_nonnull<ViewLikeOpInterface>(source))
return WalkResult::advance();
}

source = transferWrite.getBase().getDefiningOp();
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
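// Reject the pair if any user of the shared base is a view-like Op,
// i.e. a potential alias.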
if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
return WalkResult::advance();

// Check 3.
// TODO: may want to memoize this information for performance but it
// likely gets invalidated often.
DominanceInfo dom(loop);
@@ -358,7 +386,8 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
// Hoist write after.
transferWrite->moveAfter(loop);

// Rewrite `loop` with new yields by cloning and erase the original
// loop.
IRRewriter rewriter(transferRead.getContext());
NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
ArrayRef<BlockArgument> newBBArgs) {
229 changes: 229 additions & 0 deletions mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,5 +1,234 @@
// RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s

///----------------------------------------------------------------------------------------
/// Tests for vector.transfer_read + vector.transfer_write pairs
///
/// * Nested inside a single loop
/// * Indices are constant
///----------------------------------------------------------------------------------------

// The most basic example - hoisting is safe.

// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair(
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) {
func.func @hoist_basic_vector_xfer_pair(
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
// CHECK: %[[VAL_6:.*]] = "val_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
// CHECK: scf.yield %[[VAL_6]] : vector<1xf32>
// CHECK: }
// CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
scf.for %i = %lb to %ub step %step {
%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// Similar to the example above, but hoisting is no longer safe. That's due
// to an extra xfer_write inside the loop.

// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
func.func @negative_hoist_basic_vector_xfer_pair_extra_write(
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: vector.transfer_write %[[IN]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
// CHECK: }

scf.for %i = %lb to %ub step %step {
vector.transfer_write %in, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>

%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// Similar to the example above, but hoisting is no longer safe. That's due
// to an extra xfer_write into _an alias_ of %mem, the memref used by the
// original xfer pair.

// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
func.func @negative_hoist_basic_vector_xfer_pair_extra_write_into_alias(
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [1, 1] [1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: vector.transfer_write %[[IN]], %[[SV]][%[[C0]], %[[C0]]] {{.*}} : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
// CHECK: }

%sv = memref.subview %mem[0, 0][1, 1][1, 1] : memref<?x?xf32> to memref<1x1xf32, strided<[?, 1]>>
scf.for %i = %lb to %ub step %step {
vector.transfer_write %in, %sv[%c0, %c0] : vector<1xf32>, memref<1x1xf32, strided<[?, 1]>>

%r0 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
vector.transfer_write %u0, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// Similar to the example above, but the memory access is done via
// memref.assume_alignment. Hoisting is safe as the only users of the
// "alignment" Op are the xfer Ops within the loop that we want to hoist.

// CHECK-LABEL: func.func @hoist_basic_vector_xfer_pair_with_assume_align(
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
func.func @hoist_basic_vector_xfer_pair_with_assume_align(
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
// CHECK: %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
// CHECK: %[[SCF:.*]] = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[INIT:.*]] = %[[READ]]) -> (vector<1xf32>) {
// CHECK: %[[USE:.*]] = "val_use"(%[[INIT]]) : (vector<1xf32>) -> vector<1xf32>
// CHECK: }
// CHECK: vector.transfer_write %[[SCF]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>

%aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
scf.for %i = %lb to %ub step %step {
%r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
%u0 = "val_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
vector.transfer_write %u0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// Similar to the example above, but hoisting is not safe due to an extra
// memory access inside the loop via the original memref.

// CHECK-LABEL: func.func @negative_hoist_basic_vector_xfer_pair_with_assume_align(
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: vector<1xf32>) {
func.func @negative_hoist_basic_vector_xfer_pair_with_assume_align(
%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %in: vector<1xf32>) {
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[AA:.*]] = memref.assume_alignment %[[MEM]], 4 : memref<?x?xf32>
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: %[[READ:.*]] = vector.transfer_read %[[AA]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
// CHECK: "mem_use"(%[[MEM]])
// CHECK: vector.transfer_write %[[READ]], %[[AA]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
// CHECK: }

%aa = memref.assume_alignment %mem, 4 : memref<?x?xf32>
scf.for %i = %lb to %ub step %step {
%r0 = vector.transfer_read %aa[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
"mem_use"(%mem) : (memref<?x?xf32>) -> ()
vector.transfer_write %r0, %aa[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_transfers %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

///----------------------------------------------------------------------------------------
/// Tests for vector.transfer_read + vector.transfer_write pairs
///