[MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. #106436

Merged: 5 commits, Sep 4, 2024
116 changes: 78 additions & 38 deletions mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -94,8 +94,8 @@ struct LoopPipelinerInternal {
RewriterBase &rewriter);
/// Emits the epilogue, this creates `maxStage - 1` part which will contain
/// operations from stages [i; maxStage], where i is the part index.
void emitEpilogue(RewriterBase &rewriter,
llvm::SmallVector<Value> &returnValues);
LogicalResult emitEpilogue(RewriterBase &rewriter,
llvm::SmallVector<Value> &returnValues);
};

bool LoopPipelinerInternal::initializeLoopInfo(
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
LDBG("--no epilogue or predicate set -> BAIL");
return false;
}
if (dynamicLoop && peelEpilogue) {
LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
return false;
}
std::vector<std::pair<Operation *, unsigned>> schedule;
options.getScheduleFn(forOp, schedule);
if (schedule.empty()) {
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
});
int predicateIdx = i - stages[op];
if (predicates[predicateIdx]) {
OpBuilder::InsertionGuard insertGuard(rewriter);
newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
assert(newOp && "failed to predicate op.");
}
rewriter.setInsertionPointAfter(newOp);
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -561,14 +557,14 @@ LogicalResult LoopPipelinerInternal::createKernel(
}

if (predicates[useStage]) {
OpBuilder::InsertionGuard insertGuard(rewriter);
newOp = predicateFn(rewriter, newOp, predicates[useStage]);
if (!newOp)
return failure();
// Remap the results to the new predicated one.
for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
mapping.map(std::get<0>(values), std::get<1>(values));
}
rewriter.setInsertionPointAfter(newOp);
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
}
@@ -640,70 +636,113 @@ LogicalResult LoopPipelinerInternal::createKernel(
return success();
}

void LoopPipelinerInternal::emitEpilogue(
RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
LogicalResult
LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
llvm::SmallVector<Value> &returnValues) {
Location loc = forOp.getLoc();
// Emit different versions of the induction variable. They will be
// removed by dead code if not used.

// bounds_range = ub - lb
// total_iterations = (bounds_range + step - 1) / step
Type t = lb.getType();
Value minus1 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);

SmallVector<Value> predicates(maxStage + 1);
for (int64_t i = 0; i < maxStage; i++) {
Location loc = forOp.getLoc();
Type t = lb.getType();
Value minusOne =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
// number of iterations = ((ub - 1) - lb) / step
Value totalNumIteration = rewriter.create<arith::DivUIOp>(
loc,
rewriter.create<arith::SubIOp>(
loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
step);
// newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
// iterI = total_iters - 1 - i
// May go negative...
Value minusI =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
Value iterI = rewriter.create<arith::AddIOp>(
loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
minusI);
// newLastIter = lb + step * iterI
Value newlastIter = rewriter.create<arith::AddIOp>(
loc, lb,
rewriter.create<arith::MulIOp>(
loc, step,
rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));

setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);

if (dynamicLoop) {
// pred = iterI >= lb
predicates[i + 1] = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, iterI, lb);
}
}
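// Worked example (illustrative numbers only, not part of the patch): with
// lb = 0, ub = 10, step = 3 and maxStage = 2, the code above computes
//   totalIterations = ((10 - 0) + 3 - 1) / 3 = 4
//   i = 0: iterI = 4 - 1 - 0 = 3, newLastIter = 0 + 3 * 3 = 9  (last iteration)
//   i = 1: iterI = 4 - 1 - 1 = 2, newLastIter = 0 + 3 * 2 = 6  (second to last)
// When the dynamic trip count is smaller than the number of peeled stages,
// iterI can go negative, and the `iterI >= lb` predicate above switches the
// corresponding epilogue stage off.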

// Emit `maxStage - 1` epilogue part that includes operations from stages
// [i; maxStage].
for (int64_t i = 1; i <= maxStage; i++) {
SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
for (Operation *op : opOrder) {
if (stages[op] < i)
continue;
unsigned currentVersion = maxStage - stages[op] + i;
unsigned nextVersion = currentVersion + 1;
Operation *newOp =
cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
auto it = valueMapping.find(newOperand->get());
if (it != valueMapping.end()) {
Value replacement = it->second[maxStage - stages[op] + i];
Value replacement = it->second[currentVersion];
newOperand->set(replacement);
}
});
if (dynamicLoop) {
OpBuilder::InsertionGuard insertGuard(rewriter);
newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
if (!newOp)
return failure();
}
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
setValueMapping(op->getResult(destId), newOp->getResult(destId),
maxStage - stages[op] + i);

for (auto [opRes, newRes] :
llvm::zip(op->getResults(), newOp->getResults())) {
setValueMapping(opRes, newRes, currentVersion);
// If the value is a loop carried dependency update the loop argument
// mapping and keep track of the last version to replace the original
// forOp uses.
for (OpOperand &operand :
forOp.getBody()->getTerminator()->getOpOperands()) {
if (operand.get() != op->getResult(destId))
if (operand.get() != opRes)
continue;
unsigned version = maxStage - stages[op] + i + 1;
// If the version is greater than maxStage it means it maps to the
// original forOp returned value.
if (version > maxStage) {
returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
continue;
}
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
newOp->getResult(destId), version);
unsigned ri = operand.getOperandNumber();
returnValues[ri] = newRes;
Value mapVal = forOp.getRegionIterArgs()[ri];
returnMap[ri] = std::make_pair(mapVal, currentVersion);
if (nextVersion <= maxStage)
setValueMapping(mapVal, newRes, nextVersion);
}
}
}
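// Bookkeeping example (drawn from the @dynamic_loop_result test added in this
// patch, not new behavior): with maxStage = 1, the single epilogue part
// (i = 1) clones the stage-1 addf and mulf. For the mulf feeding scf.yield,
// currentVersion = maxStage - stages[op] + i = 1, so returnValues[0] becomes
// the peeled mulf and returnMap[0] records the iter arg at version 1;
// nextVersion = 2 > maxStage, so no further mapping is added, and the select
// emitted below chooses between this peeled value and the kernel's last
// result.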
if (dynamicLoop) {
Member:
Can we add a test to exercise this case?

// Select return values from this stage (live outs) based on predication.
// If the stage is valid select the peeled value, else use previous stage
// value.
for (auto pair : llvm::enumerate(returnValues)) {
unsigned ri = pair.index();
auto [mapVal, currentVersion] = returnMap[ri];
if (mapVal) {
Contributor:
Do we need to predicate all the return values? I would think that we could predicate only the values that are later used outside of the loop, otherwise it is OK to speculatively execute.

Contributor (Author):
When maxStage > 2, multiple stages are peeled. But if K is only 1, only the last stage would be executed, with selected results bypassing the previous peeled stages to the loop results (which would actually be the init values).

Some results may not be used outside the loop and would be optimized away. But since we capture these as we peel each iteration, they feed into the next iteration, and the final set replaces the for loop's results.

Contributor:
I think I see what you mean. This will happen only for dependencies within the same stage, right? For example:

i = i + 1
store(ptr, i)

If both ops are in the same stage (say, the last one), you need to predicate i = i + 1; otherwise, once you finally get to execute the store, you have the wrong value of i. But if i = i + 1 were in the previous stage, the normal accounting for value versions would take care of it.

Contributor (Author), @sjw36, Sep 3, 2024:
Yes, versioning takes care of that dependency. But a case where each stage returns a new value based on the old value requires the select.

%result:2 = scf.for {...}
// Stage N-2
%s1 = mul %result#0, %c32
%sel1 = select %valid_stage_1, %s1, %result#0
// Stage N-1
%s2 = mul %sel1, %c32
%sel2 = select %valid_stage_2, %s2, %sel1

I will add an example test.

Contributor:
Yes, please include an example; I'm not sure I understand in what case that would be needed. My thinking is that if the value doesn't escape the loop, then any uses of an op that was predicated should also be predicated; therefore we shouldn't need the select.

Contributor, @pawelszczerbuk, Sep 3, 2024:
> But a case where each stage returns a new value based on the old value requires the select.

Could we have a check for that, instead of adding selects for all the return values? I can imagine removing them may be hard afterwards.

EDIT: Actually, I take that back. We have spent some time with @ThomasRaoux analyzing different cases and we agree that predicates are needed for all the return values. I guess the test case for it won't hurt :) But the code looks correct!

Contributor (Author):

For example:

func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
  %cf0 = arith.constant 1.0 : f32
  %cf1 = arith.constant 33.0 : f32
  %cst = arith.constant 0 : index
  %res:1 = scf.for %i0 = %lb to %ub step %step iter_args (%arg0 = %cf0) -> (f32) {
    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
    %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
    %A2_elem = arith.mulf %A1_elem, %cf1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
    scf.yield %A2_elem : f32
  } { __test_pipelining_loop__ }
  memref.store %res#0, %result[%cst] : memref<?xf32>
  return
}

I see now the example predicates every operation using the predicateFn, not just the side-effecting ops. So this becomes:

  func.func @dynamic_loop_result(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: index, %arg3: index, %arg4: index) {
    %c-1 = arith.constant -1 : index
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant 1.000000e+00 : f32
    %cst_1 = arith.constant 3.300000e+01 : f32
    %c0 = arith.constant 0 : index
    %0 = arith.cmpi slt, %arg2, %arg3 : index
    %1 = scf.if %0 -> (f32) {
      %13 = memref.load %arg0[%arg2] : memref<?xf32>
      scf.yield %13 : f32
    } else {
      scf.yield %cst : f32
    }
    %2 = arith.subi %arg3, %arg4 : index
    %3:2 = scf.for %arg5 = %arg2 to %2 step %arg4 iter_args(%arg6 = %cst_0, %arg7 = %1) -> (f32, f32) {
      %13 = arith.addf %arg7, %arg6 : f32
      %14 = arith.mulf %13, %cst_1 : f32
      %15 = arith.addi %arg5, %arg4 : index
      %16 = memref.load %arg0[%15] : memref<?xf32>
      scf.yield %14, %16 : f32, f32
    }
    %4 = arith.subi %arg3, %arg2 : index
    %5 = arith.addi %4, %arg4 : index
    %6 = arith.addi %5, %c-1 : index
    %7 = arith.divui %6, %arg4 : index
    %8 = arith.addi %7, %c-1 : index
    %9 = arith.cmpi sge, %8, %arg2 : index
    %10 = scf.if %9 -> (f32) {
      %13 = arith.addf %3#1, %3#0 : f32
      scf.yield %13 : f32
    } else {
      scf.yield %cst : f32
    }
    %11 = scf.if %9 -> (f32) {
      %13 = arith.mulf %10, %cst_1 : f32
      scf.yield %13 : f32
    } else {
      scf.yield %cst : f32
    }
    %12 = arith.select %9, %11, %3#0 : f32   /// redundant
    memref.store %12, %arg1[%c0] : memref<?xf32>
    return
  }

As you can see, every operation is guarded (including ops that do not produce a loop result), so it doesn't really do speculative execution.

If only side-effecting ops are guarded and only results are selected based on stage range, results would be:

  func.func @dynamic_loop_result(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: index, %arg3: index, %arg4: index) {
    %c-1 = arith.constant -1 : index
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant 1.000000e+00 : f32
    %cst_1 = arith.constant 3.300000e+01 : f32
    %c0 = arith.constant 0 : index
    %0 = arith.cmpi slt, %arg2, %arg3 : index
    %1 = scf.if %0 -> (f32) {
      %13 = memref.load %arg0[%arg2] : memref<?xf32>
      scf.yield %13 : f32
    } else {
      scf.yield %cst : f32
    }
    %2 = arith.subi %arg3, %arg4 : index
    %3:2 = scf.for %arg5 = %arg2 to %2 step %arg4 iter_args(%arg6 = %cst_0, %arg7 = %1) -> (f32, f32) {
      %13 = arith.addf %arg7, %arg6 : f32
      %14 = arith.mulf %13, %cst_1 : f32
      %15 = arith.addi %arg5, %arg4 : index
      %16 = memref.load %arg0[%15] : memref<?xf32>
      scf.yield %14, %16 : f32, f32
    }
    %4 = arith.subi %arg3, %arg2 : index
    %5 = arith.addi %4, %arg4 : index
    %6 = arith.addi %5, %c-1 : index
    %7 = arith.divui %6, %arg4 : index
    %8 = arith.addi %7, %c-1 : index
    %9 = arith.cmpi sge, %8, %arg2 : index
    %10 = arith.addf %3#1, %3#0 : f32
    %11 = arith.mulf %10, %cst_1 : f32
    %12 = arith.select %9, %11, %3#0 : f32
    memref.store %12, %arg1[%c0] : memref<?xf32>
    return
  }

And this seems to be what the Prologue logic is doing as well (see line 343).

unsigned nextVersion = currentVersion + 1;
Value pred = predicates[currentVersion];
Value prevValue = valueMapping[mapVal][currentVersion];
auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
prevValue);
returnValues[ri] = selOp;
if (nextVersion <= maxStage)
setValueMapping(mapVal, selOp, nextVersion);
}
}
}
}
return success();
}

void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -760,7 +799,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
if (options.peelEpilogue) {
// 4. Emit the epilogue after the new forOp.
rewriter.setInsertionPointAfter(newForOp);
pipeliner.emitEpilogue(rewriter, returnValues);
if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
return failure();
}
// 5. Erase the original loop and replace the uses with the epilogue output.
if (forOp->getNumResults() > 0)
103 changes: 99 additions & 4 deletions mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -764,11 +764,44 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// NOEPILOGUE: memref.load %[[A]][%[[IV3]]] : memref<?xf32>
// NOEPILOGUE: scf.yield %[[V2]], %[[L3]] : f32, f32

// In case dynamic loop pipelining is off check that the transformation didn't
// apply.

Contributor:
Do we want to keep a test checking this case (not pipelining when dynamic loop support is turned off)?

Contributor (Author):
I would need to add a new switch to TestSCFUtils and probably a new test file so we don't run all the tests again without dynamic loop support. Or perhaps add it to the annotate run? Is that acceptable?

Contributor:
Hmm, not sure if this is worth the effort to validate that the transformation is disabled. I think I'm OK if you'd want to skip it.
// Check for predicated epilogue for dynamic loop.
// CHECK-LABEL: dynamic_loop(
// CHECK-NOT: memref.load
// CHECK: scf.for
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
// CHECK: memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
// CHECK: %[[ADDF_24:.*]] = arith.addf %[[ARG7]], %{{.*}}
// CHECK: %[[MULI_25:.*]] = arith.muli %{{.*}}, %{{.*}}
// CHECK: %[[ADDI_26:.*]] = arith.addi %[[ARG5]], %[[MULI_25]]
// CHECK: %[[LOAD_27:.*]] = memref.load %{{.*}}[%[[ADDI_26]]]
// CHECK: scf.yield %[[ADDF_24]], %[[LOAD_27]]
// CHECK: }
// CHECK: %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}}
// CHECK: %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %{{.*}}
// CHECK: %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %{{.*}}-1
// CHECK: %[[DIVUI_13:.*]] = arith.divui %[[ADDI_12]], %{{.*}}
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
// CHECK: scf.if %[[CMPI_17]] {
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
// CHECK: } else {
// CHECK: }
// CHECK: %[[IF_23:.*]] = scf.if %[[CMPI_22]] -> (f32) {
// CHECK: %[[ADDF_24:.*]] = arith.addf %{{.*}}#1, %{{.*}}
// CHECK: scf.yield %[[ADDF_24]]
// CHECK: } else {
// CHECK: scf.yield %{{.*}}
// CHECK: }
// CHECK: scf.if %[[CMPI_22]] {
// CHECK: memref.store %[[IF_23]], %{{.*}}[%[[ADDI_16]]]
// CHECK: } else {
// CHECK: }
// CHECK: return
func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
%cf = arith.constant 1.0 : f32
scf.for %i0 = %lb to %ub step %step {
@@ -781,6 +814,68 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %

// -----

// NOEPILOGUE-LABEL: func.func @dynamic_loop_result
// NOEPILOGUE: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
// NOEPILOGUE: %[[SUBI_3:.*]] = arith.subi %{{.*}}, %{{.*}}
// NOEPILOGUE: %[[CMPI_4:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_3]]
// NOEPILOGUE: %[[ADDF_5:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
// NOEPILOGUE: %[[MULF_6:.*]] = arith.mulf %[[ADDF_5]], %{{.*}}
// NOEPILOGUE: %[[ADDI_7:.*]] = arith.addi %[[ARG5]], %{{.*}}
// NOEPILOGUE: %[[IF_8:.*]] = scf.if %[[CMPI_4]]
// NOEPILOGUE: %[[LOAD_9:.*]] = memref.load %{{.*}}[%[[ADDI_7]]]
// NOEPILOGUE: scf.yield %[[LOAD_9]]
// NOEPILOGUE: } else {
// NOEPILOGUE: scf.yield %{{.*}}
// NOEPILOGUE: }
// NOEPILOGUE: scf.yield %[[MULF_6]], %[[IF_8]]
// NOEPILOGUE: }
// NOEPILOGUE: memref.store %{{.*}}#0, %{{.*}}[%{{.*}}]

// Check for predicated epilogue for dynamic loop.
// CHECK-LABEL: func.func @dynamic_loop_result
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
// CHECK: %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
// CHECK: %[[MULF_14:.*]] = arith.mulf %[[ADDF_13]], %{{.*}}
// CHECK: %[[ADDI_15:.*]] = arith.addi %[[ARG5]], %{{.*}}
// CHECK: %[[LOAD_16:.*]] = memref.load %{{.*}}[%[[ADDI_15]]]
// CHECK: scf.yield %[[MULF_14]], %[[LOAD_16]]
// CHECK: }
// CHECK: %[[SUBI_4:.*]] = arith.subi %{{.*}}, %{{.*}}
// CHECK: %[[ADDI_5:.*]] = arith.addi %[[SUBI_4]], %{{.*}}
// CHECK: %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
// CHECK: %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
// CHECK: %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_9]]
// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
// CHECK: scf.yield %[[ADDF_13]]
// CHECK: } else {
// CHECK: scf.yield %{{.*}}
// CHECK: }
// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_9]]
// CHECK: %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}}
// CHECK: scf.yield %[[MULF_13]]
// CHECK: } else {
// CHECK: scf.yield %{{.*}}
// CHECK: }
// CHECK: %[[SELECT_12:.*]] = arith.select %[[CMPI_9]], %[[IF_11]], %{{.*}}#0
// CHECK: memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}]
func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
%cf0 = arith.constant 1.0 : f32
%cf1 = arith.constant 33.0 : f32
%cst = arith.constant 0 : index
%res:1 = scf.for %i0 = %lb to %ub step %step iter_args (%arg0 = %cf0) -> (f32) {
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
%A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
%A2_elem = arith.mulf %A1_elem, %cf1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
scf.yield %A2_elem : f32
} { __test_pipelining_loop__ }
memref.store %res#0, %result[%cst] : memref<?xf32>
return
}
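For context, and as an assumption rather than part of this diff: the CHECK and NOEPILOGUE prefixes above are driven by RUN lines in the header of loop-pipelining.mlir, roughly of this form (check the actual file for the exact flags):

// RUN: mlir-opt %s -test-scf-pipelining -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-scf-pipelining=no-epilogue-peeling -split-input-file | FileCheck %s --check-prefix NOEPILOGUE

In other words, the NOEPILOGUE output corresponds to peelEpilogue = false (fully predicated kernel, no peeled stages), while the plain CHECK output exercises the predicated-epilogue path added by this patch.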

// -----

// CHECK-LABEL: yield_constant_loop(
// CHECK-SAME: %[[A:.*]]: memref<?xf32>) -> f32 {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
4 changes: 2 additions & 2 deletions mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -214,12 +214,12 @@ struct TestSCFPipeliningPass
RewritePatternSet patterns(&getContext());
mlir::scf::PipeliningOption options;
options.getScheduleFn = getSchedule;
options.supportDynamicLoops = true;
options.predicateFn = predicateOp;
if (annotatePipeline)
options.annotateFn = annotate;
if (noEpiloguePeeling) {
options.supportDynamicLoops = true;
options.peelEpilogue = false;
options.predicateFn = predicateOp;
}
scf::populateSCFLoopPipeliningPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
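For downstream users of the transform, a minimal configuration that exercises the new predicated-epilogue path on dynamic loops could look like the sketch below. This is illustrative only: populateMyPipeliningPatterns and the trivial single-stage schedule are placeholders, and predicateOp stands for a user-supplied helper like the one in this test pass; none of it is part of the patch itself.

// Sketch (assumptions noted above): enable pipelining with a peeled,
// predicated epilogue for scf.for loops whose trip count is not known
// statically.
void populateMyPipeliningPatterns(RewritePatternSet &patterns) {
  mlir::scf::PipeliningOption options;
  // Placeholder schedule: every op in stage 0, in textual order. A real
  // schedule assigns ops to different stages so they can overlap.
  options.getScheduleFn =
      [](scf::ForOp forOp,
         std::vector<std::pair<Operation *, unsigned>> &schedule) {
        for (Operation &op : forOp.getBody()->without_terminator())
          schedule.emplace_back(&op, /*stage=*/0u);
      };
  // With this patch, dynamic loops may be combined with epilogue peeling as
  // long as a predicateFn is provided to guard the peeled stages.
  options.supportDynamicLoops = true;
  options.peelEpilogue = true;
  options.predicateFn = predicateOp; // e.g. a helper like the test pass's,
                                     // wrapping the op in an scf.if on the
                                     // given predicate
  scf::populateSCFLoopPipeliningPatterns(patterns, options);
}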