Skip to content

Commit 19e068b

Browse files
authored
[MLIR][SCF] Handle more cases in pipelining transform (#74007)
-Fix case where an op is scheduled in stage 0 and used with a distance of 1 -Fix case where we don't peel the epilogue and a value not part of the last stage is used outside the loop.
1 parent b6d0ee0 commit 19e068b

File tree

2 files changed

+118
-16
lines changed

2 files changed

+118
-16
lines changed

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ struct LoopPipelinerInternal {
6161
/// `idx` of `key` in the epilogue.
6262
void setValueMapping(Value key, Value el, int64_t idx);
6363

64+
/// Return the defining op of the given value, if the Value is an argument of
65+
/// the loop return the associated defining op in the loop and its distance to
66+
/// the Value.
67+
std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
68+
6469
public:
6570
/// Initalize the information for the given `op`, return true if it
6671
/// satisfies the pre-condition to apply pipelining.
@@ -240,11 +245,12 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
240245
unsigned stage = stages[op];
241246

242247
auto analyzeOperand = [&](OpOperand &operand) {
243-
Operation *def = operand.get().getDefiningOp();
248+
auto [def, distance] = getDefiningOpAndDistance(operand.get());
244249
if (!def)
245250
return;
246251
auto defStage = stages.find(def);
247-
if (defStage == stages.end() || defStage->second == stage)
252+
if (defStage == stages.end() || defStage->second == stage ||
253+
defStage->second == stage + distance)
248254
return;
249255
assert(stage > defStage->second);
250256
LiverangeInfo &info = crossStageValues[operand.get()];
@@ -261,6 +267,25 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
261267
return crossStageValues;
262268
}
263269

270+
std::pair<Operation *, int64_t>
271+
LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
272+
int64_t distance = 0;
273+
if (auto arg = dyn_cast<BlockArgument>(value)) {
274+
if (arg.getOwner() != forOp.getBody())
275+
return {nullptr, 0};
276+
// Ignore induction variable.
277+
if (arg.getArgNumber() == 0)
278+
return {nullptr, 0};
279+
distance++;
280+
value =
281+
forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
282+
}
283+
Operation *def = value.getDefiningOp();
284+
if (!def)
285+
return {nullptr, 0};
286+
return {def, distance};
287+
}
288+
264289
scf::ForOp LoopPipelinerInternal::createKernelLoop(
265290
const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
266291
&crossStageValues,
@@ -366,10 +391,9 @@ LogicalResult LoopPipelinerInternal::createKernel(
366391
rewriter.setInsertionPointAfter(newOp);
367392
continue;
368393
}
369-
auto arg = dyn_cast<BlockArgument>(operand->get());
394+
Value source = operand->get();
395+
auto arg = dyn_cast<BlockArgument>(source);
370396
if (arg && arg.getOwner() == forOp.getBody()) {
371-
// If the value is a loop carried value coming from stage N + 1 remap,
372-
// it will become a direct use.
373397
Value ret = forOp.getBody()->getTerminator()->getOperand(
374398
arg.getArgNumber() - 1);
375399
Operation *dep = ret.getDefiningOp();
@@ -378,15 +402,19 @@ LogicalResult LoopPipelinerInternal::createKernel(
378402
auto stageDep = stages.find(dep);
379403
if (stageDep == stages.end() || stageDep->second == useStage)
380404
continue;
381-
assert(stageDep->second == useStage + 1);
382-
nestedNewOp->setOperand(operand->getOperandNumber(),
383-
mapping.lookupOrDefault(ret));
384-
continue;
405+
// If the value is a loop carried value coming from stage N + 1 remap,
406+
// it will become a direct use.
407+
if (stageDep->second == useStage + 1) {
408+
nestedNewOp->setOperand(operand->getOperandNumber(),
409+
mapping.lookupOrDefault(ret));
410+
continue;
411+
}
412+
source = ret;
385413
}
386414
// For operands defined in a previous stage we need to remap it to use
387415
// the correct region argument. We look for the right version of the
388416
// Value based on the stage where it is used.
389-
Operation *def = operand->get().getDefiningOp();
417+
Operation *def = source.getDefiningOp();
390418
if (!def)
391419
continue;
392420
auto stageDef = stages.find(def);
@@ -418,9 +446,29 @@ LogicalResult LoopPipelinerInternal::createKernel(
418446
// We create a mapping between original values and the associated loop
419447
// returned values that will be needed by the epilogue.
420448
llvm::SmallVector<Value> yieldOperands;
421-
for (Value retVal : forOp.getBody()->getTerminator()->getOperands()) {
422-
yieldOperands.push_back(mapping.lookupOrDefault(retVal));
449+
for (OpOperand &yieldOperand :
450+
forOp.getBody()->getTerminator()->getOpOperands()) {
451+
Value source = mapping.lookupOrDefault(yieldOperand.get());
452+
// When we don't peel the epilogue and the yield value is used outside the
453+
// loop we need to make sure we return the version from numStages -
454+
// defStage.
455+
if (!peelEpilogue &&
456+
!forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
457+
Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
458+
if (def) {
459+
auto defStage = stages.find(def);
460+
if (defStage != stages.end() && defStage->second < maxStage) {
461+
Value pred = predicates[defStage->second];
462+
source = rewriter.create<arith::SelectOp>(
463+
pred.getLoc(), pred, source,
464+
newForOp.getBody()
465+
->getArguments()[yieldOperand.getOperandNumber() + 1]);
466+
}
467+
}
468+
}
469+
yieldOperands.push_back(source);
423470
}
471+
424472
for (auto &it : crossStageValues) {
425473
int64_t version = maxStage - it.second.lastUseStage + 1;
426474
unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
@@ -444,9 +492,11 @@ LogicalResult LoopPipelinerInternal::createKernel(
444492
Operation *def = retVal.value().getDefiningOp();
445493
assert(def && "Only support loop carried dependencies of distance 1");
446494
unsigned defStage = stages[def];
447-
setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
448-
newForOp->getResult(retVal.index()),
449-
maxStage - defStage + 1);
495+
if (defStage > 0) {
496+
setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
497+
newForOp->getResult(retVal.index()),
498+
maxStage - defStage + 1);
499+
}
450500
}
451501
rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
452502
return success();

mlir/test/Dialect/SCF/loop-pipelining.mlir

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,4 +670,56 @@ func.func @backedge_mix_order(%A: memref<?xf32>) -> f32 {
670670
scf.yield %A3_elem : f32
671671
} { __test_pipelining_loop__ }
672672
return %r : f32
673-
}
673+
}
674+
675+
// -----
676+
677+
// CHECK-LABEL: @distance_1_use
678+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
679+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
680+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
681+
// Prologue:
682+
// CHECK: %[[L0:.+]] = memref.load %{{.*}}[%[[C0]]] : memref<?xf32>
683+
// CHECK: %[[L1:.+]] = memref.load %{{.*}}[%[[C1]]] : memref<?xf32>
684+
// CHECK: %[[R:.+]]:5 = scf.for {{.*}} iter_args(%[[IDX0:.+]] = %[[C2]], %[[L2:.+]] = %[[L0]], %[[L3:.+]] = %[[L1]]
685+
// CHECK: %[[L4:.+]] = memref.load %{{.*}}[%[[IDX0]]] : memref<?xf32>
686+
// CHECK: %[[IDX1:.+]] = arith.addi %[[IDX0]], %[[C1]] : index
687+
// CHECK: memref.store %[[L2]]
688+
// CHECK: scf.yield %[[IDX1]], %[[L3]], %[[L4]]
689+
func.func @distance_1_use(%A: memref<?xf32>, %result: memref<?xf32>) {
690+
%c0 = arith.constant 0 : index
691+
%c1 = arith.constant 1 : index
692+
%c4 = arith.constant 4 : index
693+
%cf = arith.constant 1.0 : f32
694+
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%idx = %c0) -> (index) {
695+
%A_elem = memref.load %A[%idx] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : memref<?xf32>
696+
%idx1 = arith.addi %idx, %c1 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : index
697+
memref.store %A_elem, %result[%idx] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
698+
scf.yield %idx1 : index
699+
} { __test_pipelining_loop__ }
700+
return
701+
}
702+
703+
// -----
704+
705+
// NOEPILOGUE-LABEL: stage_0_value_escape(
706+
func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
707+
%c0 = arith.constant 0 : index
708+
%c1 = arith.constant 1 : index
709+
%c4 = arith.constant 4 : index
710+
%cf = arith.constant 1.0 : f32
711+
// NOEPILOGUE: %[[C3:.+]] = arith.constant 3 : index
712+
// NOEPILOGUE: %[[A:.+]] = arith.addf
713+
// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[A]],
714+
// NOEPILOGUE: %[[C:.+]] = arith.cmpi slt, %[[IV]], %[[C3]] : index
715+
// NOEPILOGUE: %[[S:.+]] = arith.select %[[C]], %{{.+}}, %[[ARG]] : f32
716+
// NOEPILOGUE: scf.yield %[[S]]
717+
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
718+
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
719+
%A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
720+
memref.store %A1_elem, %result[%c0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
721+
scf.yield %A1_elem : f32
722+
} { __test_pipelining_loop__ }
723+
memref.store %r, %result[%c1] : memref<?xf32>
724+
return
725+
}

0 commit comments

Comments
 (0)