Skip to content

[mlir][Hoisting] Hoisting vector.extract/vector.broadcast pairs #86108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2206,6 +2206,42 @@ def HoistRedundantVectorTransfersOp :
}];
}

//===----------------------------------------------------------------------===//
// HoistRedundantVectorBroadcastsOp
//===----------------------------------------------------------------------===//

// Transform dialect op `structured.hoist_redundant_vector_broadcasts`.
// It takes a single handle operand (`target`) and produces a single handle
// result (`transformed`). The per-payload-op logic is provided by the
// `applyToOne` hook declared in `extraClassDeclaration` below.
def HoistRedundantVectorBroadcastsOp :
Op<Transform_Dialect, "structured.hoist_redundant_vector_broadcasts",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformEachOpTrait, TransformOpInterface,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Hoist vector.extract / vector.broadcasts pairs out of immediately
enclosing scf::ForOp iteratively.

#### Return modes:

The operation always succeeds and returns a handle to the transformed
function op.
}];

// Handle to the payload op(s) to transform, and the handle returned for them.
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$transformed);

// Custom syntax: `<mnemonic> %target attr-dict : (operand types) -> result types`.
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";

// Extra C++ builder that only takes the target handle value.
let builders = [
OpBuilder<(ins "Value":$target)>,
];
// Declaration of the hook that is invoked with one payload operation
// (`target`) and the result list to populate for it.
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}

//===----------------------------------------------------------------------===//
// ConvertConv2DToImg2ColOp
//===----------------------------------------------------------------------===//
Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ namespace linalg {
/// when used on distributed loops with memref semantics!
void hoistRedundantVectorTransfers(Operation *root);

/// Hoist vector.extract/vector.broadcast pairs out of their immediately
/// enclosing scf::ForOp. The rewrite is applied iteratively to the ops nested
/// under `root` until no more candidates are found. A pair is hoisted only if
/// all of the following conditions are met:
///   1. The vector.extract operation is applied on an iter_arg of the loop,
///      and that iter_arg has no other use in the body of the loop.
///   2. The position of the vector.extract is either a static value, or is
///      defined outside of the loop.
///   3. The value yielded by the loop for that iter_arg is produced by a
///      vector.broadcast operation whose result has no other use.
///   4. The source type of the vector.broadcast is the same as the result
///      type of the vector.extract.
/// After hoisting, the vector.extract is applied to the loop's init value
/// before the loop, the loop carries the extracted (smaller) value, and the
/// vector.broadcast is applied to the loop result after the loop.
/// To improve hoisting opportunities, call the `moveLoopInvariantCode` helper
/// function on the candidate loop above which to hoist.
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root);

} // namespace linalg
} // namespace mlir

Expand Down
15 changes: 15 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3306,6 +3306,21 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// HoistRedundantVectorBroadcastsOp
//===----------------------------------------------------------------------===//

/// Runs `linalg::hoistRedundantVectorBroadcasts` on the payload op `target`
/// and reports `target` itself as the transformed result. Always returns
/// success.
DiagnosedSilenceableFailure
transform::HoistRedundantVectorBroadcastsOp::applyToOne(
transform::TransformRewriter &rewriter, mlir::Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// Position the rewriter at the payload op before rewriting under it.
rewriter.setInsertionPoint(target);
linalg::hoistRedundantVectorBroadcasts(rewriter, target);
// The same payload op is mapped to the result handle.
results.push_back(target);
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// ConvertConv2DToImg2ColOp.
//===----------------------------------------------------------------------===//
Expand Down
126 changes: 126 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,132 @@ using llvm::dbgs;
using namespace mlir;
using namespace mlir::linalg;

/// Replace `loop` with a new scf.for that has a different init operand and a
/// different yielded value at position `index`. The body of `loop` is moved
/// over to the new loop.
///
/// `newInitOperand` is the replacement "init" operand at position `index`.
/// `newYieldValue` is the replacement yield value of the loop at position
/// `index`.
///
/// All results of `loop` are replaced by the corresponding results of the new
/// loop, which is returned. `index` must be a valid iter_arg position.
static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
                                            scf::ForOp loop,
                                            Value newInitOperand,
                                            unsigned index,
                                            Value newYieldValue) {
  // Restore the caller's insertion point on exit.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loop.getOperation());
  auto inits = llvm::to_vector(loop.getInits());

  // Replace the init value with the new operand.
  assert(index < inits.size() && "iter_arg index out of bounds");
  inits[index] = newInitOperand;

  // The body builder intentionally does nothing: the body of the new loop is
  // populated by merging the old body into it below.
  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
      inits, [](OpBuilder &, Location, Value, ValueRange) {});

  // Update the existing yield with the replaced operand. Go through the
  // rewriter so that any attached listener is notified of the modification.
  auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
  rewriter.modifyOpInPlace(
      yieldOp, [&] { yieldOp.setOperand(index, newYieldValue); });

  // Move the loop body to the new op, remapping the old block arguments to
  // the block arguments of the new loop.
  rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
                       newLoop.getBody()->getArguments());

  // Replace the old loop.
  rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
  return newLoop;
}

// Hoist out a pair of corresponding vector.extract+vector.broadcast
// operations. This function transforms a loop like this:
//  %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
//   %e = vector.extract %iarg : t1 to t2
//   %u = "some_use"(%e) : (t2) -> t2
//   %b = vector.broadcast %u : t2 to t1
//   scf.yield %b : t1
//  }
// into the following:
//  %e = vector.extract %v: t1 to t2
//  %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
//   %u' = "some_use"(%iarg) : (t2) -> t2
//   scf.yield %u' : t2
//  }
//  %res = vector.broadcast %res' : t2 to t1
//
// The rewrite is repeated on the ops nested under `root` until no candidate
// pair is left. IR modifications are performed through `rewriter`.
void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
                                                  Operation *root) {
  bool changed = true;
  while (changed) {
    changed = false;
    // First move loop invariant ops outside of their loop. This needs to be
    // done before as we cannot move ops without interrupting the function walk.
    root->walk(
        [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });

    root->walk([&](vector::ExtractOp extractOp) {
      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
                        << *extractOp.getOperation() << "\n");

      // Only extracts directly nested in an scf.for are candidates.
      auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
      if (!loop)
        return WalkResult::advance();

      // Check that the vector to extract from is a BlockArgument.
      auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
      if (!blockArg)
        return WalkResult::advance();

      // Check that the blockArg is an iter_arg of the loop.
      OpOperand *initArg = loop.getTiedLoopInit(blockArg);
      if (!initArg)
        return WalkResult::advance();

      // If the iter_arg does not have only one use, it won't be possible to
      // hoist the extractOp out.
      if (!blockArg.hasOneUse())
        return WalkResult::advance();

      // Position of the iter_arg among the loop's inits/results.
      unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();

      // Check that the loop yields a broadcast that has just one use. The
      // yielded value may be a block argument (i.e. have no defining op), so
      // use the null-tolerant `getDefiningOp<OpTy>()` form.
      Value yieldedVal = loop.getTiedLoopYieldedValue(blockArg)->get();
      auto broadcast = yieldedVal.getDefiningOp<vector::BroadcastOp>();
      if (!broadcast || !broadcast.getResult().hasOneUse())
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");

      // The broadcast must take as input exactly the type that the extract
      // produces, so that the loop can carry that type instead.
      Type broadcastInputType = broadcast.getSourceType();
      if (broadcastInputType != extractOp.getType())
        return WalkResult::advance();

      // The position of the extract must be defined outside of the loop if
      // it is dynamic.
      for (auto operand : extractOp.getDynamicPosition())
        if (!loop.isDefinedOutsideOfLoop(operand))
          return WalkResult::advance();

      // Make the extract read from the loop's init value and move it in front
      // of the loop; move the broadcast right after the loop.
      rewriter.modifyOpInPlace(extractOp, [&] {
        extractOp.getVectorMutable().assign(initArg->get());
      });
      rewriter.moveOpBefore(extractOp, loop);
      rewriter.moveOpAfter(broadcast, loop);

      // Rebuild the loop so that it carries the extracted value and yields
      // the broadcast's source directly.
      scf::ForOp newLoop = replaceWithDifferentYield(
          rewriter, loop, extractOp.getResult(), index, broadcast.getSource());

      LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");

      // Users of the loop result must now consume the broadcast, and the
      // broadcast itself must consume the new loop result. The order matters:
      // the broadcast is rewired last so that it is not affected by the RAUW.
      rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
      rewriter.modifyOpInPlace(
          broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });

      // The IR changed under the walk: stop and restart from scratch.
      changed = true;
      return WalkResult::interrupt();
    });
  }
}

static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
LoopLikeOpInterface loop) {
Value source = transferRead.getSource();
Expand Down
109 changes: 109 additions & 0 deletions mlir/test/Dialect/Linalg/hoisting.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,112 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// Test hoisting of vector.extract/vector.broadcast pairs
//
// The loop-carried vector<3x4xf32> is only read through a static-position
// vector.extract and is rebuilt by the yielded vector.broadcast. The checks
// below expect the extract to be applied to the function argument before the
// loop, the loop to yield a vector<4xf32>, and the broadcast to be applied to
// the loop result after the loop.

// CHECK-LABEL: func.func @hoist_vector_broadcasts
// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>

func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
%bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
%extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
%use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
%broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
scf.yield %broadcast : vector<3x4xf32>
}
return %bcast_vec : vector<3x4xf32>
}

// Transform script: match every func.func in the payload and apply the
// hoisting transform to it.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_broadcasts %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// Test hoisting of vector.extract/vector.broadcast pairs with dynamic position
//
// The extract position %pos is a function argument, i.e. it is defined outside
// of the loop, so the pair is still expected to be hoisted.

// CHECK-LABEL:  func.func @hoist_vector_broadcasts_dynamic
//  CHECK-SAME:    (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
//       CHECK:    %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
//  CHECK-NEXT:    %[[LOOP:.+]] = scf.for {{.*}} {
//  CHECK-NEXT:      %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
//  CHECK-NEXT:      scf.yield %[[USE]] : vector<4xf32>
//  CHECK-NEXT:    }
//  CHECK-NEXT:    %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
//  CHECK-NEXT:    return %[[BCAST]] : vector<3x4xf32>

func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
  %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
    %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
    %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
    %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
    scf.yield %broadcast : vector<3x4xf32>
  }
  return %bcast_vec : vector<3x4xf32>
}

// Transform script: match every func.func in the payload and apply the
// hoisting transform to it.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.structured.hoist_redundant_vector_broadcasts %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args
//
// Both iter_args are only read through a static-position vector.extract and
// rebuilt by a yielded vector.broadcast, so both pairs are expected to be
// hoisted. The relative order of the two hoisted extracts, of the two uses in
// the loop body, and of the two broadcasts is not pinned by the checks (DAG
// directives are used for them).

// CHECK-LABEL: func.func @hoist_vector_broadcasts_multiple
// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
// CHECK-SAME: %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
// CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
// CHECK-DAG: %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
// CHECK-NEXT: %[[LOOP:.+]]:2 = scf.for {{.*}} {
// CHECK-DAG: %[[USE1:.+]] = "some_use1"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
// CHECK-DAG: %[[USE2:.+]] = "some_use2"({{.*}}) : (vector<5xf32>) -> vector<5xf32>
// CHECK-NEXT: scf.yield %[[USE1]], %[[USE2]] : vector<4xf32>, vector<5xf32>
// CHECK-NEXT: }
// CHECK-DAG: %[[BCAST1:.+]] = vector.broadcast %[[LOOP]]#0 : vector<4xf32> to vector<3x4xf32>
// CHECK-DAG: %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
// CHECK-NEXT: return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>

func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
%bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
%extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
%extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>
%use1 = "some_use1"(%extract1) : (vector<4xf32>) -> vector<4xf32>
%use2 = "some_use2"(%extract2) : (vector<5xf32>) -> vector<5xf32>
%broadcast1 = vector.broadcast %use1 : vector<4xf32> to vector<3x4xf32>
%broadcast2 = vector.broadcast %use2 : vector<5xf32> to vector<3x5xf32>
scf.yield %broadcast1, %broadcast2 : vector<3x4xf32>,vector<3x5xf32>
}
return %bcast_vec#0, %bcast_vec#1 : vector<3x4xf32>, vector<3x5xf32>
}

// Transform script: match every func.func in the payload and apply the
// hoisting transform to it.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.hoist_redundant_vector_broadcasts %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}