-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][scf] Uplift scf.while
to scf.for
#76108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
4414fbc
[mlir][scf] Uplift `scf.while` to `scf.for`
Hardcode84 9333a9b
Support non-index types
Hardcode84 923f2cc
cleanup
Hardcode84 e7119a4
Renamings and comments
Hardcode84 bc88fbb
renaming
Hardcode84 3dc8695
Renamed to test pass
Hardcode84 9c054f4
fix typo
Hardcode84 e439c36
fix check
Hardcode84 f619c6f
upliftWhileToForLoop func
Hardcode84 e062c93
move test pass
Hardcode84 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Transforms SCF.WhileOp's into SCF.ForOp's. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/SCF/Transforms/Passes.h" | ||
|
||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/SCF/IR/SCF.h" | ||
#include "mlir/Dialect/SCF/Transforms/Patterns.h" | ||
#include "mlir/IR/Dominance.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(scf::WhileOp loop, | ||
PatternRewriter &rewriter) const override { | ||
return upliftWhileToForLoop(rewriter, loop); | ||
} | ||
}; | ||
} // namespace | ||
|
||
FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, | ||
scf::WhileOp loop) { | ||
Block *beforeBody = loop.getBeforeBody(); | ||
if (!llvm::hasSingleElement(beforeBody->without_terminator())) | ||
return rewriter.notifyMatchFailure(loop, "Loop body must have single op"); | ||
|
||
auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front()); | ||
if (!cmp) | ||
return rewriter.notifyMatchFailure(loop, | ||
"Loop body must have single cmp op"); | ||
|
||
scf::ConditionOp beforeTerm = loop.getConditionOp(); | ||
if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult()) | ||
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { | ||
diag << "Expected single condition use: " << *cmp; | ||
}); | ||
|
||
// All `before` block args must be directly forwarded to ConditionOp. | ||
// They will be converted to `scf.for` `iter_vars` except induction var. | ||
if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) | ||
return rewriter.notifyMatchFailure(loop, "Invalid args order"); | ||
|
||
using Pred = arith::CmpIPredicate; | ||
Pred predicate = cmp.getPredicate(); | ||
if (predicate != Pred::slt && predicate != Pred::sgt) | ||
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { | ||
diag << "Expected 'slt' or 'sgt' predicate: " << *cmp; | ||
}); | ||
|
||
BlockArgument inductionVar; | ||
Value ub; | ||
DominanceInfo dom; | ||
|
||
// Check if cmp has a suitable form. One of the arguments must be a `before` | ||
// block arg, other must be defined outside `scf.while` and will be treated | ||
// as upper bound. | ||
for (bool reverse : {false, true}) { | ||
auto expectedPred = reverse ? Pred::sgt : Pred::slt; | ||
if (cmp.getPredicate() != expectedPred) | ||
continue; | ||
|
||
auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs(); | ||
auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs(); | ||
|
||
auto blockArg = dyn_cast<BlockArgument>(arg1); | ||
if (!blockArg || blockArg.getOwner() != beforeBody) | ||
continue; | ||
|
||
if (!dom.properlyDominates(arg2, loop)) | ||
continue; | ||
|
||
inductionVar = blockArg; | ||
ub = arg2; | ||
break; | ||
} | ||
|
||
if (!inductionVar) | ||
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { | ||
diag << "Unrecognized cmp form: " << *cmp; | ||
}); | ||
|
||
// inductionVar must have 2 uses: one is in `cmp` and other is `condition` | ||
// arg. | ||
if (!llvm::hasNItems(inductionVar.getUses(), 2)) | ||
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { | ||
diag << "Unrecognized induction var: " << inductionVar; | ||
}); | ||
|
||
Block *afterBody = loop.getAfterBody(); | ||
scf::YieldOp afterTerm = loop.getYieldOp(); | ||
auto argNumber = inductionVar.getArgNumber(); | ||
auto afterTermIndArg = afterTerm.getResults()[argNumber]; | ||
|
||
auto inductionVarAfter = afterBody->getArgument(argNumber); | ||
|
||
Value step; | ||
|
||
// Find suitable `addi` op inside `after` block, one of the args must be an | ||
// Induction var passed from `before` block and second arg must be defined | ||
// outside of the loop and will be considered step value. | ||
// TODO: Add `subi` support? | ||
for (auto &use : inductionVarAfter.getUses()) { | ||
auto owner = dyn_cast<arith::AddIOp>(use.getOwner()); | ||
if (!owner) | ||
continue; | ||
|
||
auto other = | ||
(inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs()); | ||
if (!dom.properlyDominates(other, loop)) | ||
continue; | ||
|
||
if (afterTermIndArg != owner.getResult()) | ||
continue; | ||
|
||
step = other; | ||
break; | ||
} | ||
|
||
if (!step) | ||
return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op"); | ||
|
||
auto lb = loop.getInits()[argNumber]; | ||
|
||
assert(lb.getType().isIntOrIndex()); | ||
assert(lb.getType() == ub.getType()); | ||
assert(lb.getType() == step.getType()); | ||
|
||
llvm::SmallVector<Value> newArgs; | ||
|
||
// Populate inits for new `scf.for`, skip induction var. | ||
newArgs.reserve(loop.getInits().size()); | ||
for (auto &&[i, init] : llvm::enumerate(loop.getInits())) { | ||
if (i == argNumber) | ||
continue; | ||
|
||
newArgs.emplace_back(init); | ||
} | ||
|
||
Location loc = loop.getLoc(); | ||
|
||
// With `builder == nullptr`, ForOp::build will try to insert terminator at | ||
// the end of newly created block and we don't want it. Provide empty | ||
// dummy builder instead. | ||
auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {}; | ||
auto newLoop = | ||
rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder); | ||
|
||
Block *newBody = newLoop.getBody(); | ||
|
||
// Populate block args for `scf.for` body, move induction var to the front. | ||
newArgs.clear(); | ||
ValueRange newBodyArgs = newBody->getArguments(); | ||
for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) { | ||
if (i < argNumber) { | ||
newArgs.emplace_back(newBodyArgs[i + 1]); | ||
} else if (i == argNumber) { | ||
newArgs.emplace_back(newBodyArgs.front()); | ||
} else { | ||
newArgs.emplace_back(newBodyArgs[i]); | ||
} | ||
} | ||
|
||
rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(), | ||
newArgs); | ||
|
||
auto term = cast<scf::YieldOp>(newBody->getTerminator()); | ||
|
||
// Populate new yield args, skipping the induction var. | ||
newArgs.clear(); | ||
for (auto &&[i, arg] : llvm::enumerate(term.getResults())) { | ||
if (i == argNumber) | ||
continue; | ||
|
||
newArgs.emplace_back(arg); | ||
} | ||
|
||
OpBuilder::InsertionGuard g(rewriter); | ||
rewriter.setInsertionPoint(term); | ||
rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs); | ||
|
||
// Compute induction var value after loop execution. | ||
rewriter.setInsertionPointAfter(newLoop); | ||
Value one; | ||
if (isa<IndexType>(step.getType())) { | ||
one = rewriter.create<arith::ConstantIndexOp>(loc, 1); | ||
} else { | ||
one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType()); | ||
} | ||
|
||
Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one); | ||
Value len = rewriter.create<arith::SubIOp>(loc, ub, lb); | ||
len = rewriter.create<arith::AddIOp>(loc, len, stepDec); | ||
len = rewriter.create<arith::DivSIOp>(loc, len, step); | ||
len = rewriter.create<arith::SubIOp>(loc, len, one); | ||
Value res = rewriter.create<arith::MulIOp>(loc, len, step); | ||
res = rewriter.create<arith::AddIOp>(loc, lb, res); | ||
|
||
// Reconstruct `scf.while` results, inserting final induction var value | ||
// into proper place. | ||
newArgs.clear(); | ||
llvm::append_range(newArgs, newLoop.getResults()); | ||
newArgs.insert(newArgs.begin() + argNumber, res); | ||
rewriter.replaceOp(loop, newArgs); | ||
return newLoop; | ||
} | ||
|
||
void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) { | ||
patterns.add<UpliftWhileOp>(patterns.getContext()); | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems to me that we'd better decouple the transformation from the pattern.
That is: expose it as a C++ API
FailureOr<scf::ForOp> upliftWhileToForLoop(scf::WhileOp loop);
exposed in a public header, and have the client code decide their integration point (this could connect to the Transform dialect for example).Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Do we still want to make while-to-for uplifting part of while op canonicalization later?