-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. #106436
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9be66a1
df8268d
5603ded
969e8bf
1f17e1e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -94,8 +94,8 @@ struct LoopPipelinerInternal { | |
RewriterBase &rewriter); | ||
/// Emits the epilogue, this creates `maxStage - 1` part which will contain | ||
/// operations from stages [i; maxStage], where i is the part index. | ||
void emitEpilogue(RewriterBase &rewriter, | ||
llvm::SmallVector<Value> &returnValues); | ||
LogicalResult emitEpilogue(RewriterBase &rewriter, | ||
llvm::SmallVector<Value> &returnValues); | ||
}; | ||
|
||
bool LoopPipelinerInternal::initializeLoopInfo( | ||
|
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo( | |
LDBG("--no epilogue or predicate set -> BAIL"); | ||
return false; | ||
} | ||
if (dynamicLoop && peelEpilogue) { | ||
LDBG("--dynamic loop doesn't support epilogue yet -> BAIL"); | ||
return false; | ||
} | ||
std::vector<std::pair<Operation *, unsigned>> schedule; | ||
options.getScheduleFn(forOp, schedule); | ||
if (schedule.empty()) { | ||
|
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { | |
}); | ||
int predicateIdx = i - stages[op]; | ||
if (predicates[predicateIdx]) { | ||
OpBuilder::InsertionGuard insertGuard(rewriter); | ||
newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]); | ||
assert(newOp && "failed to predicate op."); | ||
} | ||
rewriter.setInsertionPointAfter(newOp); | ||
if (annotateFn) | ||
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i); | ||
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { | ||
|
@@ -561,14 +557,14 @@ LogicalResult LoopPipelinerInternal::createKernel( | |
} | ||
|
||
if (predicates[useStage]) { | ||
OpBuilder::InsertionGuard insertGuard(rewriter); | ||
newOp = predicateFn(rewriter, newOp, predicates[useStage]); | ||
if (!newOp) | ||
return failure(); | ||
// Remap the results to the new predicated one. | ||
for (auto values : llvm::zip(op->getResults(), newOp->getResults())) | ||
mapping.map(std::get<0>(values), std::get<1>(values)); | ||
} | ||
rewriter.setInsertionPointAfter(newOp); | ||
if (annotateFn) | ||
annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0); | ||
} | ||
|
@@ -640,70 +636,113 @@ LogicalResult LoopPipelinerInternal::createKernel( | |
return success(); | ||
} | ||
|
||
void LoopPipelinerInternal::emitEpilogue( | ||
RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) { | ||
LogicalResult | ||
LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, | ||
llvm::SmallVector<Value> &returnValues) { | ||
Location loc = forOp.getLoc(); | ||
// Emit different versions of the induction variable. They will be | ||
// removed by dead code if not used. | ||
|
||
// bounds_range = ub - lb | ||
// total_iterations = (bounds_range + step - 1) / step | ||
Type t = lb.getType(); | ||
Value minus1 = | ||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1)); | ||
Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb); | ||
Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step); | ||
Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1); | ||
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step); | ||
|
||
SmallVector<Value> predicates(maxStage + 1); | ||
for (int64_t i = 0; i < maxStage; i++) { | ||
Location loc = forOp.getLoc(); | ||
Type t = lb.getType(); | ||
Value minusOne = | ||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1)); | ||
// number of iterations = ((ub - 1) - lb) / step | ||
Value totalNumIteration = rewriter.create<arith::DivUIOp>( | ||
loc, | ||
rewriter.create<arith::SubIOp>( | ||
loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb), | ||
step); | ||
// newLastIter = lb + step * ((((ub - 1) - lb) / step) - i) | ||
// iterI = total_iters - 1 - i | ||
// May go negative... | ||
Value minusI = | ||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i)); | ||
Value iterI = rewriter.create<arith::AddIOp>( | ||
loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1), | ||
minusI); | ||
// newLastIter = lb + step * iterI | ||
Value newlastIter = rewriter.create<arith::AddIOp>( | ||
loc, lb, | ||
rewriter.create<arith::MulIOp>( | ||
loc, step, | ||
rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI))); | ||
loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI)); | ||
|
||
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i); | ||
|
||
if (dynamicLoop) { | ||
// pred = iterI >= lb | ||
predicates[i + 1] = rewriter.create<arith::CmpIOp>( | ||
loc, arith::CmpIPredicate::sge, iterI, lb); | ||
} | ||
} | ||
|
||
// Emit `maxStage - 1` epilogue part that includes operations from stages | ||
// [i; maxStage]. | ||
for (int64_t i = 1; i <= maxStage; i++) { | ||
SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size()); | ||
for (Operation *op : opOrder) { | ||
if (stages[op] < i) | ||
continue; | ||
unsigned currentVersion = maxStage - stages[op] + i; | ||
unsigned nextVersion = currentVersion + 1; | ||
Operation *newOp = | ||
cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { | ||
auto it = valueMapping.find(newOperand->get()); | ||
if (it != valueMapping.end()) { | ||
Value replacement = it->second[maxStage - stages[op] + i]; | ||
Value replacement = it->second[currentVersion]; | ||
newOperand->set(replacement); | ||
} | ||
}); | ||
if (dynamicLoop) { | ||
OpBuilder::InsertionGuard insertGuard(rewriter); | ||
newOp = predicateFn(rewriter, newOp, predicates[currentVersion]); | ||
if (!newOp) | ||
return failure(); | ||
} | ||
if (annotateFn) | ||
annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1); | ||
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { | ||
setValueMapping(op->getResult(destId), newOp->getResult(destId), | ||
maxStage - stages[op] + i); | ||
|
||
for (auto [opRes, newRes] : | ||
llvm::zip(op->getResults(), newOp->getResults())) { | ||
setValueMapping(opRes, newRes, currentVersion); | ||
// If the value is a loop carried dependency update the loop argument | ||
// mapping and keep track of the last version to replace the original | ||
// forOp uses. | ||
for (OpOperand &operand : | ||
forOp.getBody()->getTerminator()->getOpOperands()) { | ||
if (operand.get() != op->getResult(destId)) | ||
if (operand.get() != opRes) | ||
continue; | ||
unsigned version = maxStage - stages[op] + i + 1; | ||
// If the version is greater than maxStage it means it maps to the | ||
// original forOp returned value. | ||
if (version > maxStage) { | ||
returnValues[operand.getOperandNumber()] = newOp->getResult(destId); | ||
continue; | ||
} | ||
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], | ||
newOp->getResult(destId), version); | ||
unsigned ri = operand.getOperandNumber(); | ||
returnValues[ri] = newRes; | ||
Value mapVal = forOp.getRegionIterArgs()[ri]; | ||
returnMap[ri] = std::make_pair(mapVal, currentVersion); | ||
if (nextVersion <= maxStage) | ||
setValueMapping(mapVal, newRes, nextVersion); | ||
} | ||
} | ||
} | ||
if (dynamicLoop) { | ||
// Select return values from this stage (live outs) based on predication. | ||
// If the stage is valid select the peeled value, else use previous stage | ||
// value. | ||
for (auto pair : llvm::enumerate(returnValues)) { | ||
unsigned ri = pair.index(); | ||
auto [mapVal, currentVersion] = returnMap[ri]; | ||
if (mapVal) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to predicate all the return values? I would think that we could predicate only the values that are later used outside of the loop, otherwise it is OK to speculatively execute. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When maxStage > 2 there are multiple stages peeled. But if Some results may not be used outside loop, and would be optimized away. But since we capture these as we peel each iteration, they feed to the next iteration, and the final set replaces forLoop results. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I see what you mean. This will happen only for dependencies within the same stage, right? For example:
If both ops are in the same stage (say: last), you need to predicate There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, versioning takes care of that dependency. But a case where each stage returns a new value based on the old value, requires the select.
I will add an example test. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes please include en example I'm not sure I understand in what case that would be needed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Could we have a check for that, instead of adding selects for all the return values? I can imagine removing them may be hard afterwards EDIT: actually I take that back. We have spent some time with @ThomasRaoux analyzing different cases and we agree that predicates are needed for all the return values. I guess the test case for it won't hurt :) But the code looks correct! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For example:
I see now the example predicates every operation using the predicateFn, not just the side-effecting ops. So this becomes:
As you can see every operations is guarded (including ops that do not produce a loop result). And it doesn't really do speculative execution then. If only side-effecting ops are guarded and only results are selected based on stage range, results would be:
And this seems to be what the Prologue logic is doing as well (see line 343). |
||
unsigned nextVersion = currentVersion + 1; | ||
Value pred = predicates[currentVersion]; | ||
Value prevValue = valueMapping[mapVal][currentVersion]; | ||
auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(), | ||
prevValue); | ||
returnValues[ri] = selOp; | ||
if (nextVersion <= maxStage) | ||
setValueMapping(mapVal, selOp, nextVersion); | ||
} | ||
} | ||
} | ||
} | ||
return success(); | ||
} | ||
|
||
void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { | ||
|
@@ -760,7 +799,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, | |
if (options.peelEpilogue) { | ||
// 4. Emit the epilogue after the new forOp. | ||
rewriter.setInsertionPointAfter(newForOp); | ||
pipeliner.emitEpilogue(rewriter, returnValues); | ||
if (failed(pipeliner.emitEpilogue(rewriter, returnValues))) | ||
return failure(); | ||
} | ||
// 5. Erase the original loop and replace the uses with the epilogue output. | ||
if (forOp->getNumResults() > 0) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -764,11 +764,44 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub: | |
// NOEPILOGUE: memref.load %[[A]][%[[IV3]]] : memref<?xf32> | ||
// NOEPILOGUE: scf.yield %[[V2]], %[[L3]] : f32, f32 | ||
|
||
// In case dynamic loop pipelining is off check that the transformation didn't | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want to keep a test checking this case (not pipelining when dynamic loop support is turned off?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would need to add a new switch to TestSCFUtils and probably a new test file so we don't run all the tests again without dynamic loop support. Or perhaps add it to the annotate run? Is that acceptable? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, not sure if this is worth the effort to validate that the transformation is disabled. I think I'm OK if you'd want to skip it. |
||
// apply. | ||
// Check for predicated epilogue for dynamic loop. | ||
// CHECK-LABEL: dynamic_loop( | ||
// CHECK-NOT: memref.load | ||
// CHECK: scf.for | ||
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}) | ||
// CHECK: memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]] | ||
// CHECK: %[[ADDF_24:.*]] = arith.addf %[[ARG7]], %{{.*}} | ||
// CHECK: %[[MULI_25:.*]] = arith.muli %{{.*}}, %{{.*}} | ||
// CHECK: %[[ADDI_26:.*]] = arith.addi %[[ARG5]], %[[MULI_25]] | ||
// CHECK: %[[LOAD_27:.*]] = memref.load %{{.*}}[%[[ADDI_26]]] | ||
// CHECK: scf.yield %[[ADDF_24]], %[[LOAD_27]] | ||
// CHECK: } | ||
// CHECK: %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}} | ||
// CHECK: %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %{{.*}} | ||
// CHECK: %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %{{.*}}-1 | ||
// CHECK: %[[DIVUI_13:.*]] = arith.divui %[[ADDI_12]], %{{.*}} | ||
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1 | ||
// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]] | ||
// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]] | ||
// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}} | ||
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1 | ||
// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1 | ||
// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]] | ||
// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]] | ||
// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}} | ||
// CHECK: scf.if %[[CMPI_17]] { | ||
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]] | ||
// CHECK: } else { | ||
// CHECK: } | ||
// CHECK: %[[IF_23:.*]] = scf.if %[[CMPI_22]] -> (f32) { | ||
// CHECK: %[[ADDF_24:.*]] = arith.addf %{{.*}}#1, %{{.*}} | ||
// CHECK: scf.yield %[[ADDF_24]] | ||
// CHECK: } else { | ||
// CHECK: scf.yield %{{.*}} | ||
// CHECK: } | ||
// CHECK: scf.if %[[CMPI_22]] { | ||
// CHECK: memref.store %[[IF_23]], %{{.*}}[%[[ADDI_16]]] | ||
// CHECK: } else { | ||
// CHECK: } | ||
// CHECK: return | ||
func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) { | ||
%cf = arith.constant 1.0 : f32 | ||
scf.for %i0 = %lb to %ub step %step { | ||
|
@@ -781,6 +814,68 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, % | |
|
||
// ----- | ||
|
||
// NOEPILOGUE-LABEL: func.func @dynamic_loop_result | ||
// NOEPILOGUE: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}) | ||
// NOEPILOGUE: %[[SUBI_3:.*]] = arith.subi %{{.*}}, %{{.*}} | ||
// NOEPILOGUE: %[[CMPI_4:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_3]] | ||
// NOEPILOGUE: %[[ADDF_5:.*]] = arith.addf %[[ARG7]], %[[ARG6]] | ||
// NOEPILOGUE: %[[MULF_6:.*]] = arith.mulf %[[ADDF_5]], %{{.*}} | ||
// NOEPILOGUE: %[[ADDI_7:.*]] = arith.addi %[[ARG5]], %{{.*}} | ||
// NOEPILOGUE: %[[IF_8:.*]] = scf.if %[[CMPI_4]] | ||
// NOEPILOGUE: %[[LOAD_9:.*]] = memref.load %{{.*}}[%[[ADDI_7]]] | ||
// NOEPILOGUE: scf.yield %[[LOAD_9]] | ||
// NOEPILOGUE: } else { | ||
// NOEPILOGUE: scf.yield %{{.*}} | ||
// NOEPILOGUE: } | ||
// NOEPILOGUE: scf.yield %[[MULF_6]], %[[IF_8]] | ||
// NOEPILOGUE: } | ||
// NOEPILOGUE: memref.store %{{.*}}#0, %{{.*}}[%{{.*}}] | ||
|
||
// Check for predicated epilogue for dynamic loop. | ||
// CHECK-LABEL: func.func @dynamic_loop_result | ||
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}) | ||
// CHECK: %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]] | ||
// CHECK: %[[MULF_14:.*]] = arith.mulf %[[ADDF_13]], %{{.*}} | ||
// CHECK: %[[ADDI_15:.*]] = arith.addi %[[ARG5]], %{{.*}} | ||
// CHECK: %[[LOAD_16:.*]] = memref.load %{{.*}}[%[[ADDI_15]]] | ||
// CHECK: scf.yield %[[MULF_14]], %[[LOAD_16]] | ||
// CHECK: } | ||
// CHECK: %[[SUBI_4:.*]] = arith.subi %{{.*}}, %{{.*}} | ||
// CHECK: %[[ADDI_5:.*]] = arith.addi %[[SUBI_4]], %{{.*}} | ||
// CHECK: %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1 | ||
// CHECK: %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}} | ||
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1 | ||
// CHECK: %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}} | ||
// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_9]] | ||
// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0 | ||
// CHECK: scf.yield %[[ADDF_13]] | ||
// CHECK: } else { | ||
// CHECK: scf.yield %{{.*}} | ||
// CHECK: } | ||
// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_9]] | ||
// CHECK: %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}} | ||
// CHECK: scf.yield %[[MULF_13]] | ||
// CHECK: } else { | ||
// CHECK: scf.yield %{{.*}} | ||
// CHECK: } | ||
// CHECK: %[[SELECT_12:.*]] = arith.select %[[CMPI_9]], %[[IF_11]], %{{.*}}#0 | ||
// CHECK: memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}] | ||
func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) { | ||
%cf0 = arith.constant 1.0 : f32 | ||
%cf1 = arith.constant 33.0 : f32 | ||
%cst = arith.constant 0 : index | ||
%res:1 = scf.for %i0 = %lb to %ub step %step iter_args (%arg0 = %cf0) -> (f32) { | ||
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32> | ||
%A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32 | ||
%A2_elem = arith.mulf %A1_elem, %cf1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32 | ||
scf.yield %A2_elem : f32 | ||
} { __test_pipelining_loop__ } | ||
memref.store %res#0, %result[%cst] : memref<?xf32> | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: yield_constant_loop( | ||
// CHECK-SAME: %[[A:.*]]: memref<?xf32>) -> f32 { | ||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a test to excercise this case?