Skip to content

Commit 35b292e

Browse files
authored
[mlir][Hoisting] Hoisting vector.extract/vector.broadcast pairs (#86108)
This transformation, inspired by what is done in hoist_redundant_transfers, hoists pairs of extract/broadcast operations out of scf.for loops. It changes a loop of the form: ``` %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) { %e = vector.extract %iarg : t1 to t2 %u = "some_use"(%e) : (t2) -> t2 %b = vector.broadcast %u : t2 to t1 scf.yield %b : t1 } ``` into the following: ``` %e = vector.extract %v: t1 to t2 %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) { %u' = "some_use"(%iarg) : (t2) -> t2 scf.yield %u' : t2 } %res = vector.broadcast %res' : t2 to t1 ```
1 parent e8b31fb commit 35b292e

File tree

5 files changed

+297
-0
lines changed

5 files changed

+297
-0
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2206,6 +2206,42 @@ def HoistRedundantVectorTransfersOp :
22062206
}];
22072207
}
22082208

2209+
//===----------------------------------------------------------------------===//
2210+
// HoistRedundantVectorBroadcastsOp
2211+
//===----------------------------------------------------------------------===//
2212+
2213+
def HoistRedundantVectorBroadcastsOp :
2214+
Op<Transform_Dialect, "structured.hoist_redundant_vector_broadcasts",
2215+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
2216+
TransformEachOpTrait, TransformOpInterface,
2217+
ReportTrackingListenerFailuresOpTrait]> {
2218+
let description = [{
2219+
Hoist vector.extract / vector.broadcast pairs out of immediately
2220+
enclosing scf::ForOp iteratively.
2221+
2222+
#### Return modes:
2223+
2224+
The operation always succeeds and returns a handle to the transformed
2225+
function op.
2226+
}];
2227+
2228+
let arguments = (ins TransformHandleTypeInterface:$target);
2229+
let results = (outs TransformHandleTypeInterface:$transformed);
2230+
2231+
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
2232+
2233+
let builders = [
2234+
OpBuilder<(ins "Value":$target)>,
2235+
];
2236+
let extraClassDeclaration = [{
2237+
::mlir::DiagnosedSilenceableFailure applyToOne(
2238+
::mlir::transform::TransformRewriter &rewriter,
2239+
::mlir::Operation *target,
2240+
::mlir::transform::ApplyToEachResultList &results,
2241+
::mlir::transform::TransformState &state);
2242+
}];
2243+
}
2244+
22092245
//===----------------------------------------------------------------------===//
22102246
// ConvertConv2DToImg2ColOp
22112247
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,17 @@ namespace linalg {
4343
/// when used on distributed loops with memref semantics!
4444
void hoistRedundantVectorTransfers(Operation *root);
4545

46+
/// Hoist vector.extract/vector.broadcast pairs out of immediately enclosing
47+
/// scf::ForOp iteratively, if the following conditions are met:
48+
/// 1. The vector.extract operation is applied on an iter_argument, and no
49+
/// other operator is using this argument in the body of the loop.
50+
/// 2. The position of the vector.extract is either a static value, or defined
51+
/// outside of the loop.
52+
/// 3. The vector.broadcast operation is yielded by the loop.
53+
/// To improve hoisting opportunities, call the `moveLoopInvariantCode` helper
54+
/// function on the candidate loop above which to hoist.
55+
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root);
56+
4657
} // namespace linalg
4758
} // namespace mlir
4859

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3306,6 +3306,21 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
33063306
return DiagnosedSilenceableFailure::success();
33073307
}
33083308

3309+
//===----------------------------------------------------------------------===//
3310+
// HoistRedundantVectorBroadcastsOp
3311+
//===----------------------------------------------------------------------===//
3312+
3313+
// Applies linalg::hoistRedundantVectorBroadcasts to `target` with the
// rewriter's insertion point set at `target`, then records `target` itself as
// the result payload. `state` is not used. This implementation always returns
// success, whether or not anything was hoisted.
DiagnosedSilenceableFailure
3314+
transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3315+
transform::TransformRewriter &rewriter, mlir::Operation *target,
3316+
transform::ApplyToEachResultList &results,
3317+
transform::TransformState &state) {
3318+
rewriter.setInsertionPoint(target);
3319+
linalg::hoistRedundantVectorBroadcasts(rewriter, target);
3320+
results.push_back(target);
3321+
return DiagnosedSilenceableFailure::success();
3322+
}
3323+
33093324
//===----------------------------------------------------------------------===//
33103325
// ConvertConv2DToImg2ColOp.
33113326
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,132 @@ using llvm::dbgs;
4343
using namespace mlir;
4444
using namespace mlir::linalg;
4545

46+
/// Replace `loop` with a new loop that has a different init operand at
47+
/// position `index`. The body of this loop is moved over to the new loop.
48+
///
49+
/// `newInitOperands` specifies the replacement "init" operands.
50+
/// `newYieldValue` is the replacement yield value of the loop at position
51+
/// `index`.
52+
static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
53+
scf::ForOp loop,
54+
Value newInitOperand,
55+
unsigned index,
56+
Value newYieldValue) {
57+
OpBuilder::InsertionGuard g(rewriter);
58+
rewriter.setInsertionPoint(loop.getOperation());
59+
auto inits = llvm::to_vector(loop.getInits());
60+
61+
// Replace the init value with the new operand.
62+
assert(index < inits.size());
63+
inits[index] = newInitOperand;
64+
65+
scf::ForOp newLoop = rewriter.create<scf::ForOp>(
66+
loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
67+
inits, [](OpBuilder &, Location, Value, ValueRange) {});
68+
69+
// Generate the new yield with the replaced operand.
70+
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
71+
yieldOp.setOperand(index, newYieldValue);
72+
73+
// Move the loop body to the new op.
74+
rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
75+
newLoop.getBody()->getArguments());
76+
77+
// Replace the old loop.
78+
rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
79+
return newLoop;
80+
}
81+
82+
// Hoist out a pair of corresponding vector.extract+vector.broadcast
83+
// operations. This function transforms a loop like this:
84+
// %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
85+
// %e = vector.extract %iarg : t1 to t2
86+
// %u = "some_use"(%e) : (t2) -> t2
87+
// %b = vector.broadcast %u : t2 to t1
88+
// scf.yield %b : t1
89+
// }
90+
// into the following:
91+
// %e = vector.extract %v: t1 to t2
92+
// %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
93+
// %u' = "some_use"(%iarg) : (t2) -> t2
94+
// scf.yield %u' : t2
95+
// }
96+
// %res = vector.broadcast %res' : t2 to t1
97+
void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
98+
Operation *root) {
99+
bool changed = true;
100+
while (changed) {
101+
changed = false;
102+
// First move loop invariant ops outside of their loop. This needs to be
103+
// done before as we cannot move ops without interrupting the function walk.
104+
root->walk(
105+
[&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
106+
107+
root->walk([&](vector::ExtractOp extractOp) {
108+
LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
109+
<< *extractOp.getOperation() << "\n");
110+
111+
auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
112+
if (!loop)
113+
return WalkResult::advance();
114+
115+
// Check that the vector to extract from is a BlockArgument.
116+
auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
117+
if (!blockArg)
118+
return WalkResult::advance();
119+
120+
// Check that the blockArg is an iter_arg of the loop.
121+
OpOperand *initArg = loop.getTiedLoopInit(blockArg);
122+
if (!initArg)
123+
return WalkResult::advance();
124+
125+
// If the iter_arg does not have only one use, it won't be possible to
126+
// hoist the extractOp out.
127+
if (!blockArg.hasOneUse())
128+
return WalkResult::advance();
129+
130+
unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
131+
132+
// Check that the loop yields a broadcast that has just one use.
133+
Operation *yieldedVal =
134+
loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
135+
auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
136+
if (!broadcast || !broadcast.getResult().hasOneUse())
137+
return WalkResult::advance();
138+
139+
LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
140+
141+
Type broadcastInputType = broadcast.getSourceType();
142+
if (broadcastInputType != extractOp.getType())
143+
return WalkResult::advance();
144+
145+
// The position of the extract must be defined outside of the loop if
146+
// it is dynamic.
147+
for (auto operand : extractOp.getDynamicPosition())
148+
if (!loop.isDefinedOutsideOfLoop(operand))
149+
return WalkResult::advance();
150+
151+
rewriter.modifyOpInPlace(broadcast, [&] {
152+
extractOp.getVectorMutable().assign(initArg->get());
153+
});
154+
loop.moveOutOfLoop(extractOp);
155+
rewriter.moveOpAfter(broadcast, loop);
156+
157+
scf::ForOp newLoop = replaceWithDifferentYield(
158+
rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
159+
160+
LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
161+
162+
rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
163+
rewriter.modifyOpInPlace(
164+
broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
165+
166+
changed = true;
167+
return WalkResult::interrupt();
168+
});
169+
}
170+
}
171+
46172
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
47173
LoopLikeOpInterface loop) {
48174
Value source = transferRead.getSource();

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,3 +565,112 @@ module attributes {transform.with_named_sequence} {
565565
transform.yield
566566
}
567567
}
568+
569+
// -----
570+
571+
// Test hoisting of vector.extract/vector.broadcast pairs
572+
573+
// CHECK-LABEL: func.func @hoist_vector_broadcasts
574+
// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
575+
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
576+
// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
577+
// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
578+
// CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
579+
// CHECK-NEXT: }
580+
// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
581+
// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
582+
583+
func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
584+
%bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
585+
%extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
586+
%use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
587+
%broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
588+
scf.yield %broadcast : vector<3x4xf32>
589+
}
590+
return %bcast_vec : vector<3x4xf32>
591+
}
592+
593+
module attributes {transform.with_named_sequence} {
594+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
595+
%0 = transform.structured.match ops{["func.func"]} in %arg1
596+
: (!transform.any_op) -> !transform.any_op
597+
transform.structured.hoist_redundant_vector_broadcasts %0
598+
: (!transform.any_op) -> !transform.any_op
599+
transform.yield
600+
}
601+
}
602+
603+
// -----
604+
605+
// Test hoisting of vector.extract/vector.broadcast pairs with dynamic position
606+
607+
// CHECK-LABEL: func.func @hoist_vector_broadcasts_dynamic
608+
// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
609+
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
610+
// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
611+
// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
612+
// CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
613+
// CHECK-NEXT: }
614+
// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
615+
// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
616+
617+
func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
618+
%bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
619+
%extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
620+
%use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
621+
%broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
622+
scf.yield %broadcast : vector<3x4xf32>
623+
}
624+
return %bcast_vec : vector<3x4xf32>
625+
}
626+
627+
module attributes {transform.with_named_sequence} {
628+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
629+
%0 = transform.structured.match ops{["func.func"]} in %arg1
630+
: (!transform.any_op) -> !transform.any_op
631+
transform.structured.hoist_redundant_vector_broadcasts %0
632+
: (!transform.any_op) -> !transform.any_op
633+
transform.yield
634+
}
635+
}
636+
637+
// -----
638+
639+
// Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args
640+
641+
// CHECK-LABEL: func.func @hoist_vector_broadcasts_multiple
642+
// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
643+
// CHECK-SAME: %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
644+
// CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
645+
// CHECK-DAG: %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
646+
// CHECK-NEXT: %[[LOOP:.+]]:2 = scf.for {{.*}} {
647+
// CHECK-DAG: %[[USE1:.+]] = "some_use1"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
648+
// CHECK-DAG: %[[USE2:.+]] = "some_use2"({{.*}}) : (vector<5xf32>) -> vector<5xf32>
649+
// CHECK-NEXT: scf.yield %[[USE1]], %[[USE2]] : vector<4xf32>, vector<5xf32>
650+
// CHECK-NEXT: }
651+
// CHECK-DAG: %[[BCAST1:.+]] = vector.broadcast %[[LOOP]]#0 : vector<4xf32> to vector<3x4xf32>
652+
// CHECK-DAG: %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
653+
// CHECK-NEXT: return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>
654+
655+
func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
656+
%bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
657+
%extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
658+
%extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>
659+
%use1 = "some_use1"(%extract1) : (vector<4xf32>) -> vector<4xf32>
660+
%use2 = "some_use2"(%extract2) : (vector<5xf32>) -> vector<5xf32>
661+
%broadcast1 = vector.broadcast %use1 : vector<4xf32> to vector<3x4xf32>
662+
%broadcast2 = vector.broadcast %use2 : vector<5xf32> to vector<3x5xf32>
663+
scf.yield %broadcast1, %broadcast2 : vector<3x4xf32>,vector<3x5xf32>
664+
}
665+
return %bcast_vec#0, %bcast_vec#1 : vector<3x4xf32>, vector<3x5xf32>
666+
}
667+
668+
module attributes {transform.with_named_sequence} {
669+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
670+
%0 = transform.structured.match ops{["func.func"]} in %arg1
671+
: (!transform.any_op) -> !transform.any_op
672+
transform.structured.hoist_redundant_vector_broadcasts %0
673+
: (!transform.any_op) -> !transform.any_op
674+
transform.yield
675+
}
676+
}

0 commit comments

Comments
 (0)