-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Hoisting] Hoisting vector.extract/vector.broadcast pairs #86108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Motivation: during internal development we noticed quite a few instances of such patterns where many loops would start with |
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Steven Varoumas (stevenvar) Changes: This transformation, inspired by what is done in hoist_redundant_transfers, hoists pairs of extract/broadcast operations out of scf.for loops. It changes a loop of the form:
into the following:
Full diff: https://github.com/llvm/llvm-project/pull/86108.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 4f34016066b4ce..7e1b2894dc0126 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2213,6 +2213,42 @@ def HoistRedundantVectorTransfersOp :
}];
}
+//===----------------------------------------------------------------------===//
+// HoistRedundantVectorBroadcastsOp
+//===----------------------------------------------------------------------===//
+
+// Transform dialect op with mnemonic
+// `structured.hoist_redundant_vector_broadcasts`. It consumes one handle
+// (`$target`) and produces one handle (`$transformed`).
+def HoistRedundantVectorBroadcastsOp :
+ Op<Transform_Dialect, "structured.hoist_redundant_vector_broadcasts",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformEachOpTrait, TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Hoist vector.extract / vector.broadcasts pairs out of immediately
+ enclosing scf::ForOp iteratively.
+
+ #### Return modes:
+
+ The operation always succeeds and returns a handle to the transformed
+ function op.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+ // Extra builder that takes only the target handle value.
+ let builders = [
+ OpBuilder<(ins "Value":$target)>,
+ ];
+ // Declares applyToOne, whose `target` parameter is typed as func::FuncOp.
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::func::FuncOp target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// ConvertConv2DToImg2ColOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index 186e83a57580f3..11886d4876a97f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -43,6 +43,8 @@ namespace linalg {
/// when used on distributed loops with memref semantics!
void hoistRedundantVectorTransfers(Operation *root);
+/// Hoists vector.extract / vector.broadcast pairs out of the scf.for loops
+/// nested under `root`. A loop whose iter_arg is only read by a
+/// vector.extract, and whose matching yielded value is a vector.broadcast of
+/// the extracted type, is rewritten to carry the extracted (smaller) value
+/// instead; the extract is placed before the loop and the broadcast after it.
+/// The rewrite is repeated until no further candidate is found.
+void hoistRedundantVectorBroadcasts(Operation *root);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ecf9983124821a..d7e6a21a565a75 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3224,6 +3224,20 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// HoistRedundantVectorBroadcastsOp
+//===----------------------------------------------------------------------===//
+
+/// Runs linalg::hoistRedundantVectorBroadcasts on the payload function
+/// `target`, records `target` itself as the result, and reports success
+/// unconditionally.
+///
+/// NOTE(review): `rewriter` and `state` are unused here. The hoisting helper
+/// creates its own IRRewriter, so its IR changes are not routed through
+/// `rewriter` -- confirm this is acceptable for an op that carries
+/// ReportTrackingListenerFailuresOpTrait.
+DiagnosedSilenceableFailure
+transform::HoistRedundantVectorBroadcastsOp::applyToOne(
+ transform::TransformRewriter &rewriter, func::FuncOp target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ linalg::hoistRedundantVectorBroadcasts(target);
+ // The function op is modified in place, so the same op is returned.
+ results.push_back(target);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// ConvertConv2DToImg2ColOp.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 34c9b2c282965c..98521cd745216c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -43,6 +43,120 @@ using llvm::dbgs;
using namespace mlir;
using namespace mlir::linalg;
+/// Replaces `loop` with a new scf.for that has the same bounds, step and body,
+/// except that:
+///   * the init operand at position `index` becomes `newInitOperand`, and
+///   * the scf.yield operand at position `index` becomes `newYieldValue`.
+/// The body of `loop` is moved (not cloned) into the new loop, the old loop is
+/// replaced by the new loop's results, and the new loop is returned.
+///
+/// `index` counts iter_args (0-based) and does not include the induction
+/// variable. The types of `newInitOperand` and `newYieldValue` may differ from
+/// the original iter_arg type; callers are responsible for fixing up the users
+/// of the block argument and of result `index` accordingly.
+static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
+                                            scf::ForOp loop,
+                                            Value newInitOperand, int index,
+                                            Value newYieldValue) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(loop.getOperation());
+  auto inits = llvm::to_vector(loop.getInits());
+
+  // Replace the init value with the new operand.
+  inits[index] = newInitOperand;
+
+  // The body builder is intentionally empty: the body is filled in below by
+  // merging the old loop's block into the new one.
+  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
+      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
+      inits, [](OpBuilder &, Location, Value, ValueRange) {});
+
+  // Rewrite only the yield operand tied to `index`. Replacing all uses of the
+  // previously yielded value would also rewrite unrelated users (other yield
+  // operands, ops inside the loop) with a value of a different type.
+  auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
+  rewriter.modifyOpInPlace(
+      yieldOp, [&] { yieldOp->setOperand(index, newYieldValue); });
+
+  // Move the loop body to the new op.
+  rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
+                       newLoop.getBody()->getArguments().take_front(
+                           loop.getBody()->getNumArguments()));
+
+  // Replace the old loop.
+  rewriter.replaceOp(loop.getOperation(),
+                     newLoop->getResults().take_front(loop.getNumResults()));
+  return newLoop;
+}
+
+// Hoist out a pair of corresponding vector.extract+vector.broadcast
+// operations. This function transforms a loop like this:
+//  %loop = scf.for _ = _ to _ step _ iter_args(%iterarg = %v) -> (t1) {
+//   %e = vector.extract %iterarg : t1 to t2
+//   %u = // do something with %e : t2
+//   %b = vector.broadcast %u : t2 to t1
+//   scf.yield %b : t1
+//  }
+// into the following:
+//  %e = vector.extract %v: t1 to t2
+//  %loop' = scf.for _ = _ to _ step _ iter_args(%iterarg = %e) -> (t2) {
+//   %u' = // do something with %iterarg : t2
+//   scf.yield %u' : t2
+//  }
+//  %loop = vector.broadcast %loop' : t2 to t1
+//
+// A vector.extract is only hoisted when all of the following hold:
+//  * its parent op is an scf.for;
+//  * it reads an iter_arg of that loop, and is the iter_arg's only user;
+//  * the value yielded for that iter_arg is produced by a vector.broadcast
+//    whose only user is the yield and whose source type equals the extract's
+//    result type;
+//  * every dynamic extract position is defined outside of the loop.
+// The process is repeated until no candidate is left.
+void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    // First move loop invariant ops outside of their loop. This needs to be
+    // done before as we cannot move ops without interrupting the function walk.
+    root->walk(
+        [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
+
+    root->walk([&](vector::ExtractOp extractOp) {
+      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
+                        << *extractOp.getOperation() << "\n");
+
+      auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
+      if (!loop)
+        return WalkResult::advance();
+
+      // Check that the vector to extract from is a block argument.
+      auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
+      if (!blockArg)
+        return WalkResult::advance();
+
+      // Check that the block argument is an iter_arg of this very loop. It
+      // may belong to another block (e.g. a function argument or an outer
+      // loop), in which case there is no tied init operand.
+      OpOperand *initOperand = loop.getTiedLoopInit(blockArg);
+      if (!initOperand)
+        return WalkResult::advance();
+
+      // If the iter_arg does not have only one use, it won't be possible to
+      // hoist the extractOp out.
+      if (!blockArg.hasOneUse())
+        return WalkResult::advance();
+
+      Value initArg = initOperand->get();
+      auto index = blockArg.getArgNumber() - loop.getNumInductionVars();
+
+      // Check that the loop yields a broadcast.
+      auto lastOp = loop.getBody()->getTerminator();
+      auto yieldOp = dyn_cast<scf::YieldOp>(lastOp);
+      if (!yieldOp)
+        return WalkResult::advance();
+
+      // The yielded value may be a block argument (no defining op) or the
+      // result of an unrelated op; both cases must be skipped. The broadcast
+      // must also have the yield as its only user, since it is about to be
+      // moved after the loop.
+      auto broadcast =
+          yieldOp->getOperand(index).getDefiningOp<vector::BroadcastOp>();
+      if (!broadcast || !broadcast->hasOneUse())
+        return WalkResult::advance();
+
+      LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
+
+      Type broadcastInputType = broadcast.getSourceType();
+      if (broadcastInputType != extractOp.getType())
+        return WalkResult::advance();
+
+      // The position of the extract must be defined outside of the loop if
+      // it is dynamic.
+      for (auto operand : extractOp.getDynamicPosition())
+        if (!loop.isDefinedOutsideOfLoop(operand))
+          return WalkResult::advance();
+
+      // Make the extract read the loop's init value and place it before the
+      // loop; place the broadcast after the loop.
+      extractOp.getVectorMutable().assign(initArg);
+      loop.moveOutOfLoop(extractOp);
+      broadcast->moveAfter(loop);
+
+      // Rebuild the loop so that it carries the extracted value.
+      IRRewriter rewriter(extractOp.getContext());
+      auto newLoop = replaceWithDifferentYield(
+          rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
+
+      LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
+
+      // Former users of the loop result now read the hoisted broadcast, which
+      // in turn reads the new loop result.
+      newLoop.getResult(index).replaceAllUsesWith(broadcast);
+      broadcast.getSourceMutable().assign(newLoop.getResult(index));
+
+      // The IR changed under the walk: stop and restart from scratch.
+      changed = true;
+      return WalkResult::interrupt();
+    });
+  }
+}
+
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
LoopLikeOpInterface loop) {
Value source = transferRead.getSource();
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 550ffbc7bab678..5a640348be90cf 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -565,3 +565,112 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs
+
+// CHECK-LABEL: func.func @hoist_vector_broadcasts
+// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
+// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
+// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
+// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
+
+// The loop carries a vector<3x4xf32>, but each iteration only reads row 0 of
+// it and yields a broadcast of a vector<4xf32>.
+func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
+ %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
+ %extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
+ %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
+ %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
+ scf.yield %broadcast : vector<3x4xf32>
+ }
+ return %bcast_vec : vector<3x4xf32>
+}
+
+// Match every func.func in the payload and apply the hoisting transform to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_broadcasts %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs with dynamic position
+// The position %pos is a function argument, i.e. defined outside of the loop,
+// so the extract can still be hoisted.
+
+// CHECK-LABEL: func.func @hoist_vector_broadcasts_dynamic
+// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
+// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
+// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
+// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
+ %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
+ %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
+ %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
+ %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
+ scf.yield %broadcast : vector<3x4xf32>
+ }
+ return %bcast_vec : vector<3x4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_broadcasts %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args
+
+// CHECK-LABEL: func.func @hoist_vector_broadcasts_multiple
+// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
+// CHECK-SAME: %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
+// CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
+// CHECK-DAG: %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
+// CHECK-NEXT: %[[LOOP:.+]]:2 = scf.for {{.*}} {
+// CHECK-DAG: %[[USE1:.+]] = "some_use1"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK-DAG: %[[USE2:.+]] = "some_use2"({{.*}}) : (vector<5xf32>) -> vector<5xf32>
+// CHECK-NEXT: scf.yield %[[USE1]], %[[USE2]] : vector<4xf32>, vector<5xf32>
+// CHECK-NEXT: }
+// CHECK-DAG: %[[BCAST1:.+]] = vector.broadcast %[[LOOP]]#0 : vector<4xf32> to vector<3x4xf32>
+// CHECK-DAG: %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
+// CHECK-NEXT: return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>
+
+// Two independent extract/broadcast pairs, one per iter_arg, with different
+// vector shapes and different extract positions.
+func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
+ %bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
+ %extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
+ %extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>
+ %use1 = "some_use1"(%extract1) : (vector<4xf32>) -> vector<4xf32>
+ %use2 = "some_use2"(%extract2) : (vector<5xf32>) -> vector<5xf32>
+ %broadcast1 = vector.broadcast %use1 : vector<4xf32> to vector<3x4xf32>
+ %broadcast2 = vector.broadcast %use2 : vector<5xf32> to vector<3x5xf32>
+ scf.yield %broadcast1, %broadcast2 : vector<3x4xf32>,vector<3x5xf32>
+ }
+ return %bcast_vec#0, %bcast_vec#1 : vector<3x4xf32>, vector<3x5xf32>
+}
+
+// Match every func.func in the payload and apply the hoisting transform to it.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.hoist_redundant_vector_broadcasts %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
\ No newline at end of file
|
bc1e38a
to
9516800
Compare
9516800
to
e245656
Compare
Ping |
e245656
to
26ccd91
Compare
Ping |
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
Outdated
Show resolved
Hide resolved
ec49491
to
13d780e
Compare
13d780e
to
5ebc4fb
Compare
e859861
to
3835681
Compare
3835681
to
bf4eb78
Compare
Thank you @ftynse for the quick approval, are you happy with merging this patch? I do not have committer permission |
This transformation, inspired by what is done in hoist_redundant_transfers, hoists pairs of extract/broadcast operations out of scf.for loops.
It changes a loop of the form:
into the following: