|
| 1 | +//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// Transforms SCF.WhileOp's into SCF.ForOp's. |
| 10 | +// |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | + |
| 13 | +#include "mlir/Dialect/SCF/Transforms/Passes.h" |
| 14 | + |
| 15 | +#include "mlir/Dialect/Arith/IR/Arith.h" |
| 16 | +#include "mlir/Dialect/SCF/IR/SCF.h" |
| 17 | +#include "mlir/Dialect/SCF/Transforms/Patterns.h" |
| 18 | +#include "mlir/IR/Dominance.h" |
| 19 | +#include "mlir/IR/PatternMatch.h" |
| 20 | + |
| 21 | +using namespace mlir; |
| 22 | + |
| 23 | +namespace { |
| 24 | +struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> { |
| 25 | + using OpRewritePattern::OpRewritePattern; |
| 26 | + |
| 27 | + LogicalResult matchAndRewrite(scf::WhileOp loop, |
| 28 | + PatternRewriter &rewriter) const override { |
| 29 | + return upliftWhileToForLoop(rewriter, loop); |
| 30 | + } |
| 31 | +}; |
| 32 | +} // namespace |
| 33 | + |
| 34 | +FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, |
| 35 | + scf::WhileOp loop) { |
| 36 | + Block *beforeBody = loop.getBeforeBody(); |
| 37 | + if (!llvm::hasSingleElement(beforeBody->without_terminator())) |
| 38 | + return rewriter.notifyMatchFailure(loop, "Loop body must have single op"); |
| 39 | + |
| 40 | + auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front()); |
| 41 | + if (!cmp) |
| 42 | + return rewriter.notifyMatchFailure(loop, |
| 43 | + "Loop body must have single cmp op"); |
| 44 | + |
| 45 | + scf::ConditionOp beforeTerm = loop.getConditionOp(); |
| 46 | + if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult()) |
| 47 | + return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { |
| 48 | + diag << "Expected single condition use: " << *cmp; |
| 49 | + }); |
| 50 | + |
| 51 | + // All `before` block args must be directly forwarded to ConditionOp. |
| 52 | + // They will be converted to `scf.for` `iter_vars` except induction var. |
| 53 | + if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) |
| 54 | + return rewriter.notifyMatchFailure(loop, "Invalid args order"); |
| 55 | + |
| 56 | + using Pred = arith::CmpIPredicate; |
| 57 | + Pred predicate = cmp.getPredicate(); |
| 58 | + if (predicate != Pred::slt && predicate != Pred::sgt) |
| 59 | + return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { |
| 60 | + diag << "Expected 'slt' or 'sgt' predicate: " << *cmp; |
| 61 | + }); |
| 62 | + |
| 63 | + BlockArgument inductionVar; |
| 64 | + Value ub; |
| 65 | + DominanceInfo dom; |
| 66 | + |
| 67 | + // Check if cmp has a suitable form. One of the arguments must be a `before` |
| 68 | + // block arg, other must be defined outside `scf.while` and will be treated |
| 69 | + // as upper bound. |
| 70 | + for (bool reverse : {false, true}) { |
| 71 | + auto expectedPred = reverse ? Pred::sgt : Pred::slt; |
| 72 | + if (cmp.getPredicate() != expectedPred) |
| 73 | + continue; |
| 74 | + |
| 75 | + auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs(); |
| 76 | + auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs(); |
| 77 | + |
| 78 | + auto blockArg = dyn_cast<BlockArgument>(arg1); |
| 79 | + if (!blockArg || blockArg.getOwner() != beforeBody) |
| 80 | + continue; |
| 81 | + |
| 82 | + if (!dom.properlyDominates(arg2, loop)) |
| 83 | + continue; |
| 84 | + |
| 85 | + inductionVar = blockArg; |
| 86 | + ub = arg2; |
| 87 | + break; |
| 88 | + } |
| 89 | + |
| 90 | + if (!inductionVar) |
| 91 | + return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { |
| 92 | + diag << "Unrecognized cmp form: " << *cmp; |
| 93 | + }); |
| 94 | + |
| 95 | + // inductionVar must have 2 uses: one is in `cmp` and other is `condition` |
| 96 | + // arg. |
| 97 | + if (!llvm::hasNItems(inductionVar.getUses(), 2)) |
| 98 | + return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { |
| 99 | + diag << "Unrecognized induction var: " << inductionVar; |
| 100 | + }); |
| 101 | + |
| 102 | + Block *afterBody = loop.getAfterBody(); |
| 103 | + scf::YieldOp afterTerm = loop.getYieldOp(); |
| 104 | + auto argNumber = inductionVar.getArgNumber(); |
| 105 | + auto afterTermIndArg = afterTerm.getResults()[argNumber]; |
| 106 | + |
| 107 | + auto inductionVarAfter = afterBody->getArgument(argNumber); |
| 108 | + |
| 109 | + Value step; |
| 110 | + |
| 111 | + // Find suitable `addi` op inside `after` block, one of the args must be an |
| 112 | + // Induction var passed from `before` block and second arg must be defined |
| 113 | + // outside of the loop and will be considered step value. |
| 114 | + // TODO: Add `subi` support? |
| 115 | + for (auto &use : inductionVarAfter.getUses()) { |
| 116 | + auto owner = dyn_cast<arith::AddIOp>(use.getOwner()); |
| 117 | + if (!owner) |
| 118 | + continue; |
| 119 | + |
| 120 | + auto other = |
| 121 | + (inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs()); |
| 122 | + if (!dom.properlyDominates(other, loop)) |
| 123 | + continue; |
| 124 | + |
| 125 | + if (afterTermIndArg != owner.getResult()) |
| 126 | + continue; |
| 127 | + |
| 128 | + step = other; |
| 129 | + break; |
| 130 | + } |
| 131 | + |
| 132 | + if (!step) |
| 133 | + return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op"); |
| 134 | + |
| 135 | + auto lb = loop.getInits()[argNumber]; |
| 136 | + |
| 137 | + assert(lb.getType().isIntOrIndex()); |
| 138 | + assert(lb.getType() == ub.getType()); |
| 139 | + assert(lb.getType() == step.getType()); |
| 140 | + |
| 141 | + llvm::SmallVector<Value> newArgs; |
| 142 | + |
| 143 | + // Populate inits for new `scf.for`, skip induction var. |
| 144 | + newArgs.reserve(loop.getInits().size()); |
| 145 | + for (auto &&[i, init] : llvm::enumerate(loop.getInits())) { |
| 146 | + if (i == argNumber) |
| 147 | + continue; |
| 148 | + |
| 149 | + newArgs.emplace_back(init); |
| 150 | + } |
| 151 | + |
| 152 | + Location loc = loop.getLoc(); |
| 153 | + |
| 154 | + // With `builder == nullptr`, ForOp::build will try to insert terminator at |
| 155 | + // the end of newly created block and we don't want it. Provide empty |
| 156 | + // dummy builder instead. |
| 157 | + auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {}; |
| 158 | + auto newLoop = |
| 159 | + rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder); |
| 160 | + |
| 161 | + Block *newBody = newLoop.getBody(); |
| 162 | + |
| 163 | + // Populate block args for `scf.for` body, move induction var to the front. |
| 164 | + newArgs.clear(); |
| 165 | + ValueRange newBodyArgs = newBody->getArguments(); |
| 166 | + for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) { |
| 167 | + if (i < argNumber) { |
| 168 | + newArgs.emplace_back(newBodyArgs[i + 1]); |
| 169 | + } else if (i == argNumber) { |
| 170 | + newArgs.emplace_back(newBodyArgs.front()); |
| 171 | + } else { |
| 172 | + newArgs.emplace_back(newBodyArgs[i]); |
| 173 | + } |
| 174 | + } |
| 175 | + |
| 176 | + rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(), |
| 177 | + newArgs); |
| 178 | + |
| 179 | + auto term = cast<scf::YieldOp>(newBody->getTerminator()); |
| 180 | + |
| 181 | + // Populate new yield args, skipping the induction var. |
| 182 | + newArgs.clear(); |
| 183 | + for (auto &&[i, arg] : llvm::enumerate(term.getResults())) { |
| 184 | + if (i == argNumber) |
| 185 | + continue; |
| 186 | + |
| 187 | + newArgs.emplace_back(arg); |
| 188 | + } |
| 189 | + |
| 190 | + OpBuilder::InsertionGuard g(rewriter); |
| 191 | + rewriter.setInsertionPoint(term); |
| 192 | + rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs); |
| 193 | + |
| 194 | + // Compute induction var value after loop execution. |
| 195 | + rewriter.setInsertionPointAfter(newLoop); |
| 196 | + Value one; |
| 197 | + if (isa<IndexType>(step.getType())) { |
| 198 | + one = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
| 199 | + } else { |
| 200 | + one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType()); |
| 201 | + } |
| 202 | + |
| 203 | + Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one); |
| 204 | + Value len = rewriter.create<arith::SubIOp>(loc, ub, lb); |
| 205 | + len = rewriter.create<arith::AddIOp>(loc, len, stepDec); |
| 206 | + len = rewriter.create<arith::DivSIOp>(loc, len, step); |
| 207 | + len = rewriter.create<arith::SubIOp>(loc, len, one); |
| 208 | + Value res = rewriter.create<arith::MulIOp>(loc, len, step); |
| 209 | + res = rewriter.create<arith::AddIOp>(loc, lb, res); |
| 210 | + |
| 211 | + // Reconstruct `scf.while` results, inserting final induction var value |
| 212 | + // into proper place. |
| 213 | + newArgs.clear(); |
| 214 | + llvm::append_range(newArgs, newLoop.getResults()); |
| 215 | + newArgs.insert(newArgs.begin() + argNumber, res); |
| 216 | + rewriter.replaceOp(loop, newArgs); |
| 217 | + return newLoop; |
| 218 | +} |
| 219 | + |
| 220 | +void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) { |
| 221 | + patterns.add<UpliftWhileOp>(patterns.getContext()); |
| 222 | +} |
0 commit comments