Skip to content

[MLIR][SCF] Add support for pipelining dynamic loops #74350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ struct PipeliningOption {
/// lambda to generate the predicated version of operations.
bool peelEpilogue = true;

/// Control whether the transformation checks that the number of iterations is
/// greater than or equal to the number of stages and skips the transformation
/// if this is not the case. If this is set to true and the loop bounds are not
/// static, the pipeliner will have to predicate operations in the
/// prologue/epilogue.
bool supportDynamicLoops = false;

// Callback to predicate operations when the prologue or epilogue are not
// peeled. This takes the original operation, an i1 predicate value and the
// pattern rewriter. It is expected to replace the given operation with
Expand Down
137 changes: 107 additions & 30 deletions mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ struct LoopPipelinerInternal {
unsigned maxStage = 0;
DenseMap<Operation *, unsigned> stages;
std::vector<Operation *> opOrder;
int64_t ub;
int64_t lb;
int64_t step;
Value ub;
Value lb;
Value step;
bool dynamicLoop;
PipeliningOption::AnnotationlFnType annotateFn = nullptr;
bool peelEpilogue;
PipeliningOption::PredicateOpFn predicateFn = nullptr;
Expand Down Expand Up @@ -96,25 +97,41 @@ bool LoopPipelinerInternal::initializeLoopInfo(
ForOp op, const PipeliningOption &options) {
LDBG("Start initializeLoopInfo");
forOp = op;
auto upperBoundCst =
forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
auto lowerBoundCst =
forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
auto stepCst = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
ub = forOp.getUpperBound();
lb = forOp.getLowerBound();
step = forOp.getStep();

dynamicLoop = true;
auto upperBoundCst = getConstantIntValue(ub);
auto lowerBoundCst = getConstantIntValue(lb);
auto stepCst = getConstantIntValue(step);
if (!upperBoundCst || !lowerBoundCst || !stepCst) {
LDBG("--no constant bounds or step -> BAIL");
return false;
if (!options.supportDynamicLoops) {
LDBG("--dynamic loop not supported -> BAIL");
return false;
}
} else {
int64_t ubImm = upperBoundCst.value();
int64_t lbImm = lowerBoundCst.value();
int64_t stepImm = stepCst.value();
int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm);
if (numIteration > maxStage) {
dynamicLoop = false;
} else if (!options.supportDynamicLoops) {
LDBG("--fewer loop iterations than pipeline stages -> BAIL");
return false;
}
}
ub = upperBoundCst.value();
lb = lowerBoundCst.value();
step = stepCst.value();
peelEpilogue = options.peelEpilogue;
predicateFn = options.predicateFn;
if (!peelEpilogue && predicateFn == nullptr) {
if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
LDBG("--no epilogue or predicate set -> BAIL");
return false;
}
int64_t numIteration = ceilDiv(ub - lb, step);
if (dynamicLoop && peelEpilogue) {
LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
return false;
}
std::vector<std::pair<Operation *, unsigned>> schedule;
options.getScheduleFn(forOp, schedule);
if (schedule.empty()) {
Expand All @@ -128,10 +145,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
stages[opSchedule.first] = opSchedule.second;
opOrder.push_back(opSchedule.first);
}
if (numIteration <= maxStage) {
LDBG("--fewer loop iterations than pipeline stages -> BAIL");
return false;
}

// All operations need to have a stage.
for (Operation &op : forOp.getBody()->without_terminator()) {
Expand Down Expand Up @@ -204,10 +217,31 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
setValueMapping(arg, operand.get(), 0);
}
auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
Location loc = forOp.getLoc();
SmallVector<Value> predicates(maxStage);
for (int64_t i = 0; i < maxStage; i++) {
if (dynamicLoop) {
Type t = ub.getType();
// pred = ub > lb + (i * step)
Value iv = rewriter.create<arith::AddIOp>(
loc, lb,
rewriter.create<arith::MulIOp>(
loc, step,
rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(t, i))));
predicates[i] = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, iv, ub);
}

// Special handling for the induction variable, as its increment is implicit.
Value iv =
rewriter.create<arith::ConstantIndexOp>(forOp.getLoc(), lb + i * step);
// iv = lb + i * step
Type t = lb.getType();
Value iv = rewriter.create<arith::AddIOp>(
loc, lb,
rewriter.create<arith::MulIOp>(
loc, step,
rewriter.create<arith::ConstantOp>(loc,
rewriter.getIntegerAttr(t, i))));
setValueMapping(forOp.getInductionVar(), iv, i);
for (Operation *op : opOrder) {
if (stages[op] > i)
Expand All @@ -220,6 +254,12 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
newOperand->set(replacement);
}
});
int predicateIdx = i - stages[op];
if (predicates[predicateIdx]) {
newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
assert(newOp && "failed to predicate op.");
}
rewriter.setInsertionPointAfter(newOp);
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
Expand Down Expand Up @@ -326,9 +366,16 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
// `numStages - 1` iterations. Then we adjust the upper bound to remove those
// iterations.
Value newUb = forOp.getUpperBound();
if (peelEpilogue)
newUb = rewriter.create<arith::ConstantIndexOp>(forOp.getLoc(),
ub - maxStage * step);
if (peelEpilogue) {
Type t = ub.getType();
Location loc = forOp.getLoc();
// newUb = ub - maxStage * step
Value maxStageValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(t, maxStage));
Value maxStageByStep =
rewriter.create<arith::MulIOp>(loc, step, maxStageValue);
newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep);
}
auto newForOp =
rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
forOp.getStep(), newLoopArg);
Expand Down Expand Up @@ -358,9 +405,17 @@ LogicalResult LoopPipelinerInternal::createKernel(
SmallVector<Value> predicates(maxStage + 1, nullptr);
if (!peelEpilogue) {
// Create a predicate for each stage except the last stage.
Location loc = newForOp.getLoc();
Type t = ub.getType();
for (unsigned i = 0; i < maxStage; i++) {
Value c = rewriter.create<arith::ConstantIndexOp>(
newForOp.getLoc(), ub - (maxStage - i) * step);
// c = ub - (maxStage - i) * step
Value c = rewriter.create<arith::SubIOp>(
loc, ub,
rewriter.create<arith::MulIOp>(
loc, step,
rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));

Value pred = rewriter.create<arith::CmpIOp>(
newForOp.getLoc(), arith::CmpIPredicate::slt,
newForOp.getInductionVar(), c);
Expand All @@ -383,8 +438,14 @@ LogicalResult LoopPipelinerInternal::createKernel(
// version incremented based on the stage where it is used.
if (operand->get() == forOp.getInductionVar()) {
rewriter.setInsertionPoint(newOp);
Value offset = rewriter.create<arith::ConstantIndexOp>(
forOp.getLoc(), (maxStage - stages[op]) * step);

// offset = (maxStage - stages[op]) * step
Type t = step.getType();
Value offset = rewriter.create<arith::MulIOp>(
forOp.getLoc(), step,
rewriter.create<arith::ConstantOp>(
forOp.getLoc(),
rewriter.getIntegerAttr(t, maxStage - stages[op])));
Value iv = rewriter.create<arith::AddIOp>(
forOp.getLoc(), newForOp.getInductionVar(), offset);
nestedNewOp->setOperand(operand->getOperandNumber(), iv);
Expand Down Expand Up @@ -508,8 +569,24 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
// Emit different versions of the induction variable. They will be
// removed by dead code if not used.
for (int64_t i = 0; i < maxStage; i++) {
Value newlastIter = rewriter.create<arith::ConstantIndexOp>(
forOp.getLoc(), lb + step * ((((ub - 1) - lb) / step) - i));
Location loc = forOp.getLoc();
Type t = lb.getType();
Value minusOne =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
// number of iterations = ((ub - 1) - lb) / step
Value totalNumIteration = rewriter.create<arith::DivUIOp>(
loc,
rewriter.create<arith::SubIOp>(
loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
step);
// newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
Value minusI =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
Value newlastIter = rewriter.create<arith::AddIOp>(
loc, lb,
rewriter.create<arith::MulIOp>(
loc, step,
rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
}
// Emit `maxStage - 1` epilogue part that includes operations from stages
Expand Down
26 changes: 11 additions & 15 deletions mlir/test/Dialect/NVGPU/transform-pipeline-shared.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
// RUN: mlir-opt %s --transform-interpreter -canonicalize --split-input-file --verify-diagnostics | FileCheck %s

func.func @simple_depth_2_unpeeled(%global: memref<?xf32>, %result: memref<?xf32> ) {
%c0 = arith.constant 0 : index
Expand Down Expand Up @@ -78,15 +78,19 @@ module attributes {transform.with_named_sequence} {

// CHECK-LABEL: @async_depth_2_predicated
// CHECK-SAME: %[[GLOBAL:.+]]: memref
func.func @async_depth_2_predicated(%global: memref<?xf32>) {
func.func @async_depth_2_predicated(%global: memref<?xf32>, %alloc_size: index) {
%c0 = arith.constant 0 : index
%c98 = arith.constant 98 : index
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
// CHECK: %[[C4:.+]] = arith.constant 4
// CHECK-DAG: %[[C4:.+]] = arith.constant 4
// CHECK-DAG: %[[C90:.+]] = arith.constant 90
// CHECK-DAG: %[[C96:.+]] = arith.constant 96
// CHECK-DAG: %[[C8:.+]] = arith.constant 8
// CHECK-DAG: %[[C2:.+]] = arith.constant 2
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
%c4 = arith.constant 4 : index
// CHECK: %[[SHARED:.+]] = memref.alloc{{.*}} #gpu.address_space<workgroup>
%shared = memref.alloc(%c200) : memref<?xf32, #gpu.address_space<workgroup>>
%shared = memref.alloc(%alloc_size) : memref<?xf32, #gpu.address_space<workgroup>>
%c0f = arith.constant 0.0 : f32
// CHECK: %[[TOKEN0:.+]] = nvgpu.device_async_copy
// CHECK: %[[TOKEN1:.+]] = nvgpu.device_async_copy
Expand All @@ -95,16 +99,11 @@ func.func @async_depth_2_predicated(%global: memref<?xf32>) {
// CHECK-SAME: %[[ITER_ARG1:.+]] = %[[TOKEN1]]
scf.for %i = %c0 to %c98 step %c4 {
// Condition for the predication "select" below.
// CHECK: %[[C90:.+]] = arith.constant 90
// CHECK: %[[CMP0:.+]] = arith.cmpi slt, %[[I]], %[[C90]]
// CHECK: nvgpu.device_async_wait %[[ITER_ARG0]] {numGroups = 1

// Original "select" with updated induction variable.
// CHECK: %[[C96:.+]] = arith.constant 96
// CHECK: %[[C8:.+]] = arith.constant 8
// CHECK: %[[I_PLUS_8:.+]] = arith.addi %[[I]], %[[C8]]
// CHECK: %[[CMP1:.+]] = arith.cmpi slt, %[[I_PLUS_8]], %[[C96]]
// CHECK: %[[C2:.+]] = arith.constant 2
// CHECK: %[[SELECTED0:.+]] = arith.select %[[CMP1]], %[[C4]], %[[C2]]
%c96 = arith.constant 96 : index
%cond = arith.cmpi slt, %i, %c96 : index
Expand All @@ -113,14 +112,11 @@ func.func @async_depth_2_predicated(%global: memref<?xf32>) {

// Updated induction variables (two more) for the device_async_copy below.
// These are generated repeatedly by the pipeliner.
// CHECK: %[[C8_2:.+]] = arith.constant 8
// CHECK: %[[I_PLUS_8_2:.+]] = arith.addi %[[I]], %[[C8_2]]
// CHECK: %[[C8_3:.+]] = arith.constant 8
// CHECK: %[[I_PLUS_8_3:.+]] = arith.addi %[[I]], %[[C8_3]]
// CHECK: %[[I_PLUS_8_2:.+]] = arith.addi %[[I]], %[[C8]]
// CHECK: %[[I_PLUS_8_3:.+]] = arith.addi %[[I]], %[[C8]]

// The second "select" is generated by predication and selects 0 for
// the two last iterations.
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[SELECTED1:.+]] = arith.select %[[CMP0]], %[[SELECTED0]], %[[C0]]
// CHECK: %[[ASYNC_TOKEN:.+]] = nvgpu.device_async_copy %[[GLOBAL]][%[[I_PLUS_8_3]]], %[[SHARED]][%[[I_PLUS_8_2]]], 4, %[[SELECTED1]]
%token = nvgpu.device_async_copy %global[%i], %shared[%i], 4, %read_size
Expand Down
47 changes: 47 additions & 0 deletions mlir/test/Dialect/SCF/loop-pipelining.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -723,3 +723,50 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
memref.store %r, %result[%c1] : memref<?xf32>
return
}

// -----

// NOEPILOGUE-LABEL: dynamic_loop(
// NOEPILOGUE-SAME: %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>, %[[LB:.+]]: index, %[[UB:.+]]: index, %[[STEP:.+]]: index) {
// NOEPILOGUE-DAG: %[[C2:.+]] = arith.constant 2 : index
// NOEPILOGUE-DAG: %[[CSTF:.+]] = arith.constant 1.000000e+00 : f32
// Prologue:
// NOEPILOGUE: %[[P_I0:.+]] = arith.cmpi slt, %[[LB]], %[[UB]] : index
// NOEPILOGUE: %[[L0:.+]] = scf.if %[[P_I0]] -> (f32) {
// NOEPILOGUE-NEXT: memref.load %[[A]][%[[LB]]] : memref<?xf32>
// NOEPILOGUE: %[[IV1:.+]] = arith.addi %[[LB]], %[[STEP]] : index
// NOEPILOGUE: %[[P_I1:.+]] = arith.cmpi slt, %[[IV1]], %[[UB]] : index
// NOEPILOGUE: %[[IV1_2:.+]] = arith.addi %[[LB]], %[[STEP]] : index
// NOEPILOGUE: %[[V0:.+]] = scf.if %[[P_I0]] -> (f32) {
// NOEPILOGUE-NEXT: arith.addf %[[L0]], %[[CSTF]] : f32
// NOEPILOGUE: %[[L1:.+]] = scf.if %[[P_I1]] -> (f32) {
// NOEPILOGUE-NEXT: memref.load %[[A]][%[[IV1_2]]] : memref<?xf32>
// NOEPILOGUE: scf.for %[[IV2:.+]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[V1:.+]] = %[[V0]], %[[L2:.+]] = %[[L1]]) -> (f32, f32) {
// NOEPILOGUE-DAG: %[[S2:.+]] = arith.muli %[[STEP]], %[[C2]] : index
// NOEPILOGUE-DAG: %[[IT2:.+]] = arith.subi %[[UB]], %[[S2]] : index
// NOEPILOGUE-DAG: %[[P_I2:.+]] = arith.cmpi slt, %[[IV2]], %[[IT2]] : index
// NOEPILOGUE-DAG: %[[IT3:.+]] = arith.subi %[[UB]], %[[STEP]] : index
// NOEPILOGUE-DAG: %[[P_I3:.+]] = arith.cmpi slt, %[[IV2]], %[[IT3]] : index
// NOEPILOGUE: memref.store %[[V1]], %[[R]][%[[IV2]]] : memref<?xf32>
// NOEPILOGUE: %[[V2:.+]] = scf.if %[[P_I3]] -> (f32) {
// NOEPILOGUE: arith.addf %[[L2]], %[[CSTF]] : f32
// NOEPILOGUE: %[[IT4:.+]] = arith.muli %[[STEP]], %[[C2]] : index
// NOEPILOGUE: %[[IV3:.+]] = arith.addi %[[IV2]], %[[IT4]] : index
// NOEPILOGUE: %[[L3:.+]] = scf.if %[[P_I2]] -> (f32) {
// NOEPILOGUE: memref.load %[[A]][%[[IV3]]] : memref<?xf32>
// NOEPILOGUE: scf.yield %[[V2]], %[[L3]] : f32, f32

// In case dynamic loop pipelining is off check that the transformation didn't
// apply.
// CHECK-LABEL: dynamic_loop(
// CHECK-NOT: memref.load
// CHECK: scf.for
func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
%cf = arith.constant 1.0 : f32
scf.for %i0 = %lb to %ub step %step {
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
%A1_elem = arith.addf %A_elem, %cf { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : memref<?xf32>
} { __test_pipelining_loop__ }
return
}
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ struct TestSCFPipeliningPass
if (annotatePipeline)
options.annotateFn = annotate;
if (noEpiloguePeeling) {
options.supportDynamicLoops = true;
options.peelEpilogue = false;
options.predicateFn = predicateOp;
}
Expand Down