-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. #106436
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
sjw36
commented
Aug 28, 2024
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir Author: SJW (sjw36) Changes
Full diff: https://github.com/llvm/llvm-project/pull/106436.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index d8e1cc0ecef88e..95fa7c8b0ef7d5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -94,8 +94,8 @@ struct LoopPipelinerInternal {
RewriterBase &rewriter);
/// Emits the epilogue, this creates `maxStage - 1` part which will contain
/// operations from stages [i; maxStage], where i is the part index.
- void emitEpilogue(RewriterBase &rewriter,
- llvm::SmallVector<Value> &returnValues);
+ LogicalResult emitEpilogue(RewriterBase &rewriter,
+ llvm::SmallVector<Value> &returnValues);
};
bool LoopPipelinerInternal::initializeLoopInfo(
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
LDBG("--no epilogue or predicate set -> BAIL");
return false;
}
- if (dynamicLoop && peelEpilogue) {
- LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
- return false;
- }
std::vector<std::pair<Operation *, unsigned>> schedule;
options.getScheduleFn(forOp, schedule);
if (schedule.empty()) {
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
});
int predicateIdx = i - stages[op];
if (predicates[predicateIdx]) {
+ OpBuilder::InsertionGuard insertGuard(rewriter);
newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
assert(newOp && "failed to predicate op.");
}
- rewriter.setInsertionPointAfter(newOp);
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -561,6 +557,7 @@ LogicalResult LoopPipelinerInternal::createKernel(
}
if (predicates[useStage]) {
+ OpBuilder::InsertionGuard insertGuard(rewriter);
newOp = predicateFn(rewriter, newOp, predicates[useStage]);
if (!newOp)
return failure();
@@ -568,7 +565,6 @@ LogicalResult LoopPipelinerInternal::createKernel(
for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
mapping.map(std::get<0>(values), std::get<1>(values));
}
- rewriter.setInsertionPointAfter(newOp);
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
}
@@ -640,70 +636,121 @@ LogicalResult LoopPipelinerInternal::createKernel(
return success();
}
-void LoopPipelinerInternal::emitEpilogue(
- RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
+LogicalResult
+LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
+ llvm::SmallVector<Value> &returnValues) {
+ Location loc = forOp.getLoc();
// Emit different versions of the induction variable. They will be
// removed by dead code if not used.
+
+ // bounds_range = ub - lb
+ // total_iterations = bounds_range / step + (bounds_range % step ? 1 : 0)
+ Type t = lb.getType();
+ Value minus1 =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+
+ Value const_0 =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+ Value const_1 =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
+ Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
+ Value boundsRem = rewriter.create<arith::RemUIOp>(loc, boundsRange, step);
+ Value hasRem = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
+ boundsRem, const_0);
+ Value totalIterations = rewriter.create<arith::AddIOp>(
+ loc, rewriter.create<arith::DivUIOp>(loc, boundsRange, step),
+ rewriter.create<arith::SelectOp>(loc, hasRem, const_1, const_0));
+
+ SmallVector<Value> predicates(maxStage + 1);
for (int64_t i = 0; i < maxStage; i++) {
- Location loc = forOp.getLoc();
- Type t = lb.getType();
- Value minusOne =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
- // number of iterations = ((ub - 1) - lb) / step
- Value totalNumIteration = rewriter.create<arith::DivUIOp>(
- loc,
- rewriter.create<arith::SubIOp>(
- loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
- step);
- // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
+ // iterI = total_iters - 1 - i
+ // May go negative...
Value minusI =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
+ Value iterI = rewriter.create<arith::AddIOp>(
+ loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
+ minusI);
+ // newLastIter = lb + step * iterI
Value newlastIter = rewriter.create<arith::AddIOp>(
- loc, lb,
- rewriter.create<arith::MulIOp>(
- loc, step,
- rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
+ loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
+
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+
+ if (dynamicLoop) {
+ // pred = iterI >= 0
+ predicates[i + 1] = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, iterI, const_0);
+ }
}
+
// Emit `maxStage - 1` epilogue part that includes operations from stages
// [i; maxStage].
for (int64_t i = 1; i <= maxStage; i++) {
+ SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
for (Operation *op : opOrder) {
if (stages[op] < i)
continue;
+ unsigned currentVersion = maxStage - stages[op] + i;
+ unsigned nextVersion = currentVersion + 1;
Operation *newOp =
cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
auto it = valueMapping.find(newOperand->get());
if (it != valueMapping.end()) {
- Value replacement = it->second[maxStage - stages[op] + i];
+ Value replacement = it->second[currentVersion];
newOperand->set(replacement);
}
});
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
- for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
- setValueMapping(op->getResult(destId), newOp->getResult(destId),
- maxStage - stages[op] + i);
+ if (dynamicLoop) {
+ OpBuilder::InsertionGuard insertGuard(rewriter);
+ newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
+ if (!newOp)
+ return failure();
+ }
+
+ for (auto [opRes, newRes] :
+ llvm::zip(op->getResults(), newOp->getResults())) {
+ setValueMapping(opRes, newRes, currentVersion);
// If the value is a loop carried dependency update the loop argument
// mapping and keep track of the last version to replace the original
// forOp uses.
for (OpOperand &operand :
forOp.getBody()->getTerminator()->getOpOperands()) {
- if (operand.get() != op->getResult(destId))
+ if (operand.get() != opRes)
continue;
- unsigned version = maxStage - stages[op] + i + 1;
// If the version is greater than maxStage it means it maps to the
// original forOp returned value.
- if (version > maxStage) {
- returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
- continue;
- }
- setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
- newOp->getResult(destId), version);
+ unsigned ri = operand.getOperandNumber();
+ returnValues[ri] = newRes;
+ Value mapVal = forOp.getRegionIterArgs()[ri];
+ returnMap[ri] = std::make_pair(mapVal, currentVersion);
+ if (nextVersion <= maxStage)
+ setValueMapping(mapVal, newRes, nextVersion);
+ }
+ }
+ }
+ if (dynamicLoop) {
+ // Select return values from this stage (live outs) based on predication.
+ // If the stage is valid select the peeled value, else use previous stage
+ // value.
+ for (auto pair : llvm::enumerate(returnValues)) {
+ unsigned ri = pair.index();
+ auto [mapVal, currentVersion] = returnMap[ri];
+ if (mapVal) {
+ unsigned nextVersion = currentVersion + 1;
+ Value pred = predicates[currentVersion];
+ Value prevValue = valueMapping[mapVal][currentVersion];
+ auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
+ prevValue);
+ returnValues[ri] = selOp;
+ if (nextVersion <= maxStage)
+ setValueMapping(mapVal, selOp, nextVersion);
}
}
}
}
+ return success();
}
void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -760,7 +807,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
if (options.peelEpilogue) {
// 4. Emit the epilogue after the new forOp.
rewriter.setInsertionPointAfter(newForOp);
- pipeliner.emitEpilogue(rewriter, returnValues);
+ if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
+ return failure();
}
// 5. Erase the original loop and replace the uses with the epilogue output.
if (forOp->getNumResults() > 0)
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 9687f80f5ddfc8..957dc5295c0583 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -764,11 +764,46 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// NOEPILOGUE: memref.load %[[A]][%[[IV3]]] : memref<?xf32>
// NOEPILOGUE: scf.yield %[[V2]], %[[L3]] : f32, f32
-// In case dynamic loop pipelining is off check that the transformation didn't
-// apply.
+// Check for predicated epilogue for dynamic loop.
// CHECK-LABEL: dynamic_loop(
-// CHECK-NOT: memref.load
-// CHECK: scf.for
+// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+// CHECK: memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
+// CHECK: %[[ADDF_26:.*]] = arith.addf %[[ARG7]], %{{.*}}
+// CHECK: %[[MULI_27:.*]] = arith.muli %{{.*}}, %{{.*}}
+// CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG5]], %[[MULI_27]]
+// CHECK: %[[LOAD_29:.*]] = memref.load %{{.*}}[%[[ADDI_28]]]
+// CHECK: scf.yield %[[ADDF_26]], %[[LOAD_29]]
+// CHECK: }
+// CHECK: %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}}
+// CHECK: %[[REMUI_11:.*]] = arith.remui %[[SUBI_10]], %{{.*}}
+// CHECK: %[[CMPI_12:.*]] = arith.cmpi ne, %[[REMUI_11]], %{{.*}}
+// CHECK: %[[SELECT_13:.*]] = arith.select %[[CMPI_12]], %{{.*}}, %{{.*}}
+// CHECK: %[[DIVUI_14:.*]] = arith.divui %[[SUBI_10]], %{{.*}}
+// CHECK: %[[ADDI_15:.*]] = arith.addi %[[DIVUI_14]], %[[SELECT_13]]
+// CHECK: %[[ADDI_16:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1
+// CHECK: %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[ADDI_16]]
+// CHECK: %[[ADDI_18:.*]] = arith.addi %{{.*}}, %[[MULI_17]]
+// CHECK: %[[CMPI_19:.*]] = arith.cmpi sge, %[[ADDI_16]], %{{.*}}
+// CHECK: %[[ADDI_20:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1
+// CHECK: %[[ADDI_21:.*]] = arith.addi %[[ADDI_20]], %{{.*}}-1
+// CHECK: %[[MULI_22:.*]] = arith.muli %{{.*}}, %[[ADDI_21]]
+// CHECK: %[[ADDI_23:.*]] = arith.addi %{{.*}}, %[[MULI_22]]
+// CHECK: %[[CMPI_24:.*]] = arith.cmpi sge, %[[ADDI_21]], %{{.*}}
+// CHECK: scf.if %[[CMPI_19]] {
+// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_23]]]
+// CHECK: } else {
+// CHECK: }
+// CHECK: %[[IF_25:.*]] = scf.if %[[CMPI_24]] -> (f32) {
+// CHECK: %[[ADDF_26:.*]] = arith.addf %{{.*}}#1, %{{.*}}
+// CHECK: scf.yield %[[ADDF_26]]
+// CHECK: } else {
+// CHECK: scf.yield %{{.*}}
+// CHECK: }
+// CHECK: scf.if %[[CMPI_24]] {
+// CHECK: memref.store %[[IF_25]], %{{.*}}[%[[ADDI_18]]]
+// CHECK: } else {
+// CHECK: }
+// CHECK: return
func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
%cf = arith.constant 1.0 : f32
scf.for %i0 = %lb to %ub step %step {
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 8a92d840ad1302..3ff7f9966e93da 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -214,12 +214,12 @@ struct TestSCFPipeliningPass
RewritePatternSet patterns(&getContext());
mlir::scf::PipeliningOption options;
options.getScheduleFn = getSchedule;
+ options.supportDynamicLoops = true;
+ options.predicateFn = predicateOp;
if (annotatePipeline)
options.annotateFn = annotate;
if (noEpiloguePeeling) {
- options.supportDynamicLoops = true;
options.peelEpilogue = false;
- options.predicateFn = predicateOp;
}
scf::populateSCFLoopPipeliningPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
8083865
to
e1dcc2b
Compare
* Allow speculative execution and predicate results per stage.
e1dcc2b
to
9be66a1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM but I'm not super familiar here; so will wait for @ThomasRaoux to finally approve.
} | ||
} | ||
} | ||
if (dynamicLoop) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a test to exercise this case?
@@ -764,11 +764,44 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub: | |||
// NOEPILOGUE: memref.load %[[A]][%[[IV3]]] : memref<?xf32> | |||
// NOEPILOGUE: scf.yield %[[V2]], %[[L3]] : f32, f32 | |||
|
|||
// In case dynamic loop pipelining is off check that the transformation didn't |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to keep a test checking this case (i.e., that pipelining does not apply when dynamic loop support is turned off)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would need to add a new switch to TestSCFUtils and probably a new test file so we don't run all the tests again without dynamic loop support. Or perhaps add it to the annotate run? Is that acceptable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, not sure if this is worth the effort to validate that the transformation is disabled. I think I'm OK if you'd want to skip it.
for (auto pair : llvm::enumerate(returnValues)) { | ||
unsigned ri = pair.index(); | ||
auto [mapVal, currentVersion] = returnMap[ri]; | ||
if (mapVal) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to predicate all the return values? I would think that we could predicate only the values that are later used outside of the loop, otherwise it is OK to speculatively execute.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When maxStage > 2 there are multiple stages peeled. But if K is only 1, only the last stage would be executed, with selected results bypassing the previous peeled stages to the loop results (which would actually be the init values).
Some results may not be used outside loop, and would be optimized away. But since we capture these as we peel each iteration, they feed to the next iteration, and the final set replaces forLoop results.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I see what you mean. This will happen only for dependencies within the same stage, right? For example:
i = i+1
store(ptr, i)
If both ops are in the same stage (say: last), you need to predicate `i = i + 1`, otherwise once you finally get to execute the store, you have the wrong value of `i`. But if `i = i + 1` were in the previous stage, normal accounting for value versions would take care of it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, versioning takes care of that dependency. But a case where each stage returns a new value based on the old value, requires the select.
%result:2 = scf.for {...}
// Stage N-2
%s1 = mul %result#0, %c32
%sel1 = select %valid_stage_1, %s1, %result#0
// Stage N-1
%s2 = mul %sel1, %c32
%sel2 = select %valid_stage_2, %s2, %sel1
I will add an example test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes please include en example I'm not sure I understand in what case that would be needed.
My thinking is that if the value doesn't escape the loop then any uses of an op that was predicated should also be predicated, therefore we shouldn't need the select
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But a case where each stage returns a new value based on the old value, requires the select.
Could we have a check for that, instead of adding selects for all the return values? I can imagine removing them may be hard afterwards
EDIT: actually I take that back. We have spent some time with @ThomasRaoux analyzing different cases and we agree that predicates are needed for all the return values. I guess the test case for it won't hurt :) But the code looks correct!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For example:
func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
%cf0 = arith.constant 1.0 : f32
%cf1 = arith.constant 33.0 : f32
%cst = arith.constant 0 : index
%res:1 = scf.for %i0 = %lb to %ub step %step iter_args (%arg0 = %cf0) -> (f32) {
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
%A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
%A2_elem = arith.mulf %A1_elem, %cf1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
scf.yield %A2_elem : f32
} { __test_pipelining_loop__ }
memref.store %res#0, %result[%cst] : memref<?xf32>
return
}
I see now the example predicates every operation using the predicateFn, not just the side-effecting ops. So this becomes:
func.func @dynamic_loop_result(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: index, %arg3: index, %arg4: index) {
%c-1 = arith.constant -1 : index
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 1.000000e+00 : f32
%cst_1 = arith.constant 3.300000e+01 : f32
%c0 = arith.constant 0 : index
%0 = arith.cmpi slt, %arg2, %arg3 : index
%1 = scf.if %0 -> (f32) {
%13 = memref.load %arg0[%arg2] : memref<?xf32>
scf.yield %13 : f32
} else {
scf.yield %cst : f32
}
%2 = arith.subi %arg3, %arg4 : index
%3:2 = scf.for %arg5 = %arg2 to %2 step %arg4 iter_args(%arg6 = %cst_0, %arg7 = %1) -> (f32, f32) {
%13 = arith.addf %arg7, %arg6 : f32
%14 = arith.mulf %13, %cst_1 : f32
%15 = arith.addi %arg5, %arg4 : index
%16 = memref.load %arg0[%15] : memref<?xf32>
scf.yield %14, %16 : f32, f32
}
%4 = arith.subi %arg3, %arg2 : index
%5 = arith.addi %4, %arg4 : index
%6 = arith.addi %5, %c-1 : index
%7 = arith.divui %6, %arg4 : index
%8 = arith.addi %7, %c-1 : index
%9 = arith.cmpi sge, %8, %arg2 : index
%10 = scf.if %9 -> (f32) {
%13 = arith.addf %3#1, %3#0 : f32
scf.yield %13 : f32
} else {
scf.yield %cst : f32
}
%11 = scf.if %9 -> (f32) {
%13 = arith.mulf %10, %cst_1 : f32
scf.yield %13 : f32
} else {
scf.yield %cst : f32
}
%12 = arith.select %9, %11, %3#0 : f32 /// redundant
memref.store %12, %arg1[%c0] : memref<?xf32>
return
}
As you can see, every operation is guarded (including ops that do not produce a loop result). And it doesn't really do speculative execution then.
If only side-effecting ops are guarded and only results are selected based on stage range, results would be:
func.func @dynamic_loop_result(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: index, %arg3: index, %arg4: index) {
%c-1 = arith.constant -1 : index
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 1.000000e+00 : f32
%cst_1 = arith.constant 3.300000e+01 : f32
%c0 = arith.constant 0 : index
%0 = arith.cmpi slt, %arg2, %arg3 : index
%1 = scf.if %0 -> (f32) {
%13 = memref.load %arg0[%arg2] : memref<?xf32>
scf.yield %13 : f32
} else {
scf.yield %cst : f32
}
%2 = arith.subi %arg3, %arg4 : index
%3:2 = scf.for %arg5 = %arg2 to %2 step %arg4 iter_args(%arg6 = %cst_0, %arg7 = %1) -> (f32, f32) {
%13 = arith.addf %arg7, %arg6 : f32
%14 = arith.mulf %13, %cst_1 : f32
%15 = arith.addi %arg5, %arg4 : index
%16 = memref.load %arg0[%15] : memref<?xf32>
scf.yield %14, %16 : f32, f32
}
%4 = arith.subi %arg3, %arg2 : index
%5 = arith.addi %4, %arg4 : index
%6 = arith.addi %5, %c-1 : index
%7 = arith.divui %6, %arg4 : index
%8 = arith.addi %7, %c-1 : index
%9 = arith.cmpi sge, %8, %arg2 : index
%10 = arith.addf %3#1, %3#0 : f32
%11 = arith.mulf %10, %cst_1 : f32
%12 = arith.select %9, %11, %3#0 : f32
memref.store %12, %arg1[%c0] : memref<?xf32>
return
}
And this seems to be what the Prologue logic is doing as well (see line 343).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! This is a nice improvement!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks really good, thanks for the improvements
Thanks folks. I added a test for scf.for with results. |
Select epilogue results based on iteration predication and allow speculative execution. For instance, when pipelining with num_stages==3 ``` load (0) load(1) local_store(0) %res = for (0..K-1) { dot(i) load(i+2) local_store(i+1) } %d1 = dot(K-2) local_store(K-1) %s1 = select %valid_iteration1, %d1, %res#0 %d0 = dot(K-1) %s0 = select %valid_iteration0, %d0, %s1 ``` This mirrors upstream change llvm/llvm-project#106436
Select epilogue results based on iteration predication and allow speculative execution. For instance, when pipelining with num_stages==3 ``` load (0) load(1) local_store(0) %res = for (0..K-1) { dot(i) load(i+2) local_store(i+1) } %d1 = dot(K-2) local_store(K-1) %s1 = select %valid_iteration1, %d1, %res#0 %d0 = dot(K-1) %s0 = select %valid_iteration0, %d0, %s1 ``` This mirrors upstream change llvm/llvm-project#106436