Skip to content

[MLIR][SCF] Handle more cases in pipelining transform #74007

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 65 additions & 15 deletions mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ struct LoopPipelinerInternal {
/// `idx` of `key` in the epilogue.
void setValueMapping(Value key, Value el, int64_t idx);

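/// Return the defining op of the given value together with its loop-carried
/// distance. If the value is an iteration argument of the loop, return the op
/// in the loop body that defines the corresponding yielded value, with a
/// distance of 1; otherwise the distance is 0.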
std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);

public:
/// Initialize the information for the given `op`, return true if it
/// satisfies the pre-condition to apply pipelining.
Expand Down Expand Up @@ -240,11 +245,12 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
unsigned stage = stages[op];

auto analyzeOperand = [&](OpOperand &operand) {
Operation *def = operand.get().getDefiningOp();
auto [def, distance] = getDefiningOpAndDistance(operand.get());
if (!def)
return;
auto defStage = stages.find(def);
if (defStage == stages.end() || defStage->second == stage)
if (defStage == stages.end() || defStage->second == stage ||
defStage->second == stage + distance)
return;
assert(stage > defStage->second);
LiverangeInfo &info = crossStageValues[operand.get()];
Expand All @@ -261,6 +267,25 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
return crossStageValues;
}

/// Resolve `value` to the operation producing it, paired with the number of
/// loop iterations separating that operation from `value`.
///
/// - For a value that is not a block argument, the result is its defining op
///   with distance 0.
/// - For a block argument of the pipelined loop body other than the induction
///   variable (argument 0), the lookup goes through the matching operand of
///   the body terminator and the distance is 1.
/// - `{nullptr, 0}` is returned for block arguments owned by another block,
///   for the induction variable, and whenever the resolved value has no
///   defining op.
std::pair<Operation *, int64_t>
LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
  const std::pair<Operation *, int64_t> notFound{nullptr, 0};

  int64_t hops = 0;
  Value producerValue = value;
  if (auto iterArg = dyn_cast<BlockArgument>(value)) {
    auto *loopBody = forOp.getBody();
    // Arguments of foreign blocks and the induction variable carry no
    // producer inside the loop.
    if (iterArg.getOwner() != loopBody || iterArg.getArgNumber() == 0)
      return notFound;
    // Iteration argument N+1 is fed by terminator operand N.
    producerValue =
        loopBody->getTerminator()->getOperand(iterArg.getArgNumber() - 1);
    hops = 1;
  }

  if (Operation *producer = producerValue.getDefiningOp())
    return {producer, hops};
  return notFound;
}

scf::ForOp LoopPipelinerInternal::createKernelLoop(
const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
&crossStageValues,
Expand Down Expand Up @@ -366,10 +391,9 @@ LogicalResult LoopPipelinerInternal::createKernel(
rewriter.setInsertionPointAfter(newOp);
continue;
}
auto arg = dyn_cast<BlockArgument>(operand->get());
Value source = operand->get();
auto arg = dyn_cast<BlockArgument>(source);
if (arg && arg.getOwner() == forOp.getBody()) {
// If the value is a loop carried value coming from stage N + 1 remap,
// it will become a direct use.
Value ret = forOp.getBody()->getTerminator()->getOperand(
arg.getArgNumber() - 1);
Operation *dep = ret.getDefiningOp();
Expand All @@ -378,15 +402,19 @@ LogicalResult LoopPipelinerInternal::createKernel(
auto stageDep = stages.find(dep);
if (stageDep == stages.end() || stageDep->second == useStage)
continue;
assert(stageDep->second == useStage + 1);
nestedNewOp->setOperand(operand->getOperandNumber(),
mapping.lookupOrDefault(ret));
continue;
// If the value is a loop carried value coming from stage N + 1 remap,
// it will become a direct use.
if (stageDep->second == useStage + 1) {
nestedNewOp->setOperand(operand->getOperandNumber(),
mapping.lookupOrDefault(ret));
continue;
}
source = ret;
}
// For operands defined in a previous stage we need to remap it to use
// the correct region argument. We look for the right version of the
// Value based on the stage where it is used.
Operation *def = operand->get().getDefiningOp();
Operation *def = source.getDefiningOp();
if (!def)
continue;
auto stageDef = stages.find(def);
Expand Down Expand Up @@ -418,9 +446,29 @@ LogicalResult LoopPipelinerInternal::createKernel(
// We create a mapping between original values and the associated loop
// returned values that will be needed by the epilogue.
llvm::SmallVector<Value> yieldOperands;
for (Value retVal : forOp.getBody()->getTerminator()->getOperands()) {
yieldOperands.push_back(mapping.lookupOrDefault(retVal));
for (OpOperand &yieldOperand :
forOp.getBody()->getTerminator()->getOpOperands()) {
Value source = mapping.lookupOrDefault(yieldOperand.get());
// When we don't peel the epilogue and the yield value is used outside the
// loop we need to make sure we return the version from numStages -
// defStage.
if (!peelEpilogue &&
!forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
if (def) {
auto defStage = stages.find(def);
if (defStage != stages.end() && defStage->second < maxStage) {
Value pred = predicates[defStage->second];
source = rewriter.create<arith::SelectOp>(
pred.getLoc(), pred, source,
newForOp.getBody()
->getArguments()[yieldOperand.getOperandNumber() + 1]);
}
}
}
yieldOperands.push_back(source);
}

for (auto &it : crossStageValues) {
int64_t version = maxStage - it.second.lastUseStage + 1;
unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
Expand All @@ -444,9 +492,11 @@ LogicalResult LoopPipelinerInternal::createKernel(
Operation *def = retVal.value().getDefiningOp();
assert(def && "Only support loop carried dependencies of distance 1");
unsigned defStage = stages[def];
setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
newForOp->getResult(retVal.index()),
maxStage - defStage + 1);
if (defStage > 0) {
setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
newForOp->getResult(retVal.index()),
maxStage - defStage + 1);
}
}
rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
return success();
Expand Down
54 changes: 53 additions & 1 deletion mlir/test/Dialect/SCF/loop-pipelining.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -670,4 +670,56 @@ func.func @backedge_mix_order(%A: memref<?xf32>) -> f32 {
scf.yield %A3_elem : f32
} { __test_pipelining_loop__ }
return %r : f32
}
}

// -----

// CHECK-LABEL: @distance_1_use
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// Prologue:
// CHECK: %[[L0:.+]] = memref.load %{{.*}}[%[[C0]]] : memref<?xf32>
// CHECK: %[[L1:.+]] = memref.load %{{.*}}[%[[C1]]] : memref<?xf32>
// CHECK: %[[R:.+]]:5 = scf.for {{.*}} iter_args(%[[IDX0:.+]] = %[[C2]], %[[L2:.+]] = %[[L0]], %[[L3:.+]] = %[[L1]]
// CHECK: %[[L4:.+]] = memref.load %{{.*}}[%[[IDX0]]] : memref<?xf32>
// CHECK: %[[IDX1:.+]] = arith.addi %[[IDX0]], %[[C1]] : index
// CHECK: memref.store %[[L2]]
// CHECK: scf.yield %[[IDX1]], %[[L3]], %[[L4]]
// Pipelining test input: the stage-2 store reads the iteration argument %idx,
// whose yielded value (%idx1) is produced by a stage-0 op. The use is thus a
// loop-carried value (distance 1) consumed in a later stage than its producer.
func.func @distance_1_use(%A: memref<?xf32>, %result: memref<?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
// NOTE(review): %cf is not referenced anywhere else in this function.
%cf = arith.constant 1.0 : f32
// Loop over [0, 4) with step 1; %idx starts at 0 and is advanced by 1 each
// iteration through the yield below.
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%idx = %c0) -> (index) {
// Stage 0: load from %A at the carried index.
%A_elem = memref.load %A[%idx] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : memref<?xf32>
// Stage 0: compute the next carried index.
%idx1 = arith.addi %idx, %c1 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : index
// Stage 2: store the loaded element, indexed by the carried %idx.
memref.store %A_elem, %result[%idx] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
scf.yield %idx1 : index
} { __test_pipelining_loop__ }
return
}

// -----

// NOEPILOGUE-LABEL: stage_0_value_escape(
// Pipelining test input: the loop result %r is used after the loop, and the
// yielded value (%A1_elem) is produced by the stage-1 addf, i.e. in a stage
// earlier than the last one (stage 2). Presumably this exercises the path
// where the epilogue is not peeled -- confirm against the RUN lines of the
// test file.
func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
// Initial value of the loop-carried f32 accumulator %arg0.
%cf = arith.constant 1.0 : f32
// NOEPILOGUE: %[[C3:.+]] = arith.constant 3 : index
// NOEPILOGUE: %[[A:.+]] = arith.addf
// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[A]],
// NOEPILOGUE: %[[C:.+]] = arith.cmpi slt, %[[IV]], %[[C3]] : index
// NOEPILOGUE: %[[S:.+]] = arith.select %[[C]], %{{.+}}, %[[ARG]] : f32
// NOEPILOGUE: scf.yield %[[S]]
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
// Stage 0: load the element at the induction variable.
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
// Stage 1: add the loaded element to the carried value; this is the value
// yielded by the loop.
%A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
// Stage 2: store the sum to %result[0].
memref.store %A1_elem, %result[%c0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
scf.yield %A1_elem : f32
} { __test_pipelining_loop__ }
// Use of the loop result outside the loop.
memref.store %r, %result[%c1] : memref<?xf32>
return
}