Skip to content

Commit e66f97e

Browse files
authored
[mlir] Fix loop pipelining when the operand of yield is not defined in the loop body (llvm#75423)
1 parent 042a2e8 commit e66f97e

File tree

2 files changed

+77
-21
lines changed

2 files changed

+77
-21
lines changed

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ struct LoopPipelinerInternal {
9090
RewriterBase &rewriter);
9191
/// Emits the epilogue. This creates `maxStage - 1` parts, which will contain
9292
/// operations from stages [i; maxStage], where i is the part index.
93-
llvm::SmallVector<Value> emitEpilogue(RewriterBase &rewriter);
93+
void emitEpilogue(RewriterBase &rewriter,
94+
llvm::SmallVector<Value> &returnValues);
9495
};
9596

9697
bool LoopPipelinerInternal::initializeLoopInfo(
@@ -175,15 +176,18 @@ bool LoopPipelinerInternal::initializeLoopInfo(
175176
}
176177
}
177178

178-
// Only support loop carried dependency with a distance of 1. This means the
179-
// source of all the scf.yield operands needs to be defined by operations in
180-
// the loop.
179+
// Support only loop-carried dependencies with a distance of one iteration or
180+
// those defined outside of the loop. This means that any dependency within a
181+
// loop should either be on the immediately preceding iteration, the current
182+
// iteration, or on variables whose values are set before entering the loop.
181183
if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
182184
[this](Value operand) {
183185
Operation *def = operand.getDefiningOp();
184-
return !def || !stages.contains(def);
186+
return !def ||
187+
(!stages.contains(def) && forOp->isAncestor(def));
185188
})) {
186-
LDBG("--only support loop carried dependency with a distance of 1 -> BAIL");
189+
LDBG("--only support loop carried dependency with a distance of 1 or "
190+
"defined outside of the loop -> BAIL");
187191
return false;
188192
}
189193
annotateFn = options.annotateFn;
@@ -341,12 +345,17 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
341345
for (const auto &retVal :
342346
llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
343347
Operation *def = retVal.value().getDefiningOp();
344-
assert(def && "Only support loop carried dependencies of distance 1");
345-
unsigned defStage = stages[def];
346-
Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
347-
[maxStage - defStage];
348-
assert(valueVersion);
349-
newLoopArg.push_back(valueVersion);
348+
assert(def && "Only support loop carried dependencies of distance of 1 or "
349+
"outside the loop");
350+
auto defStage = stages.find(def);
351+
if (defStage != stages.end()) {
352+
Value valueVersion =
353+
valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
354+
[maxStage - defStage->second];
355+
assert(valueVersion);
356+
newLoopArg.push_back(valueVersion);
357+
} else
358+
newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
350359
}
351360
for (auto escape : crossStageValues) {
352361
LiverangeInfo &info = escape.second;
@@ -551,21 +560,25 @@ LogicalResult LoopPipelinerInternal::createKernel(
551560
for (const auto &retVal :
552561
llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
553562
Operation *def = retVal.value().getDefiningOp();
554-
assert(def && "Only support loop carried dependencies of distance 1");
555-
unsigned defStage = stages[def];
556-
if (defStage > 0) {
563+
assert(def && "Only support loop carried dependencies of distance of 1 or "
564+
"defined outside the loop");
565+
auto defStage = stages.find(def);
566+
if (defStage == stages.end()) {
567+
for (unsigned int stage = 1; stage <= maxStage; stage++)
568+
setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
569+
retVal.value(), stage);
570+
} else if (defStage->second > 0) {
557571
setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
558572
newForOp->getResult(retVal.index()),
559-
maxStage - defStage + 1);
573+
maxStage - defStage->second + 1);
560574
}
561575
}
562576
rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
563577
return success();
564578
}
565579

566-
llvm::SmallVector<Value>
567-
LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
568-
llvm::SmallVector<Value> returnValues(forOp->getNumResults());
580+
void LoopPipelinerInternal::emitEpilogue(
581+
RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
569582
// Emit different versions of the induction variable. They will be
570583
// removed by dead code if not used.
571584
for (int64_t i = 0; i < maxStage; i++) {
@@ -628,7 +641,6 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
628641
}
629642
}
630643
}
631-
return returnValues;
632644
}
633645

634646
void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -685,7 +697,7 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
685697
if (options.peelEpilogue) {
686698
// 4. Emit the epilogue after the new forOp.
687699
rewriter.setInsertionPointAfter(newForOp);
688-
returnValues = pipeliner.emitEpilogue(rewriter);
700+
pipeliner.emitEpilogue(rewriter, returnValues);
689701
}
690702
// 5. Erase the original loop and replace the uses with the epilogue output.
691703
if (forOp->getNumResults() > 0)

mlir/test/Dialect/SCF/loop-pipelining.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,3 +770,47 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
770770
} { __test_pipelining_loop__ }
771771
return
772772
}
773+
774+
// -----
775+
776+
// CHECK-LABEL: yield_constant_loop(
777+
// CHECK-SAME: %[[A:.*]]: memref<?xf32>) -> f32 {
778+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
779+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
780+
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
781+
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
782+
// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
783+
// Prologue:
784+
// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
785+
// Kernel:
786+
// CHECK-NEXT: %[[L1:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
787+
// CHECK-SAME: step %[[C1]] iter_args(%[[ARG0:.*]] = %[[CST2]], %[[ARG1:.*]] = %[[L0]]) -> (f32, f32) {
788+
// CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG1]], %[[ARG0]] : f32
789+
// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CST0]] : f32
790+
// CHECK-NEXT: memref.store %[[MUL0]], %[[A]][%[[IV]]] : memref<?xf32>
791+
// CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
792+
// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
793+
// CHECK-NEXT: scf.yield %[[CST0]], %[[L2]] : f32
794+
// CHECK-NEXT: }
795+
// Epilogue:
796+
// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[L1]]#1, %[[CST0]] : f32
797+
// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CST0]] : f32
798+
// CHECK-NEXT: memref.store %[[MUL1]], %[[A]][%[[C3]]] : memref<?xf32>
799+
// CHECK-NEXT: return %[[L1]]#0 : f32
800+
801+
// Test input: a loop from 0 to 4 (step 1) carrying a single f32 value whose
// scf.yield operand is a constant defined outside the loop body.
// The `__test_pipelining_stage__` / `__test_pipelining_op_order__` attributes
// annotate each op with a stage and an order (presumably consumed by the
// pipelining test pass -- confirm against the test driver).
func.func @yield_constant_loop(%A: memref<?xf32>) -> f32 {
802+
%c0 = arith.constant 0 : index
803+
%c1 = arith.constant 1 : index
804+
%c4 = arith.constant 4 : index
805+
%cf0 = arith.constant 0.0 : f32
806+
%cf2 = arith.constant 2.0 : f32
807+
// The iter_arg %arg0 starts at %cf2 (2.0).
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf2) -> f32 {
808+
// Annotated stage 0: the load.
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>
809+
// Annotated stage 1: the add (which reads the loop-carried %arg0), the
// multiply and the store.
%A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
810+
%A2_elem = arith.mulf %cf0, %A1_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
811+
memref.store %A2_elem, %A[%i0] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 2 } : memref<?xf32>
812+
// The yielded value is %cf0, which is defined before the loop rather than by
// an operation inside the loop body.
scf.yield %cf0: f32
813+
} { __test_pipelining_loop__ }
814+
return %r : f32
815+
}
816+

0 commit comments

Comments
 (0)