Skip to content

[mlir][scf] Uplift scf.while to scf.for #76108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
/// loop bounds and loop steps are canonicalized.
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);

/// Populate patterns to uplift `scf.while` ops to `scf.for`.
/// Uplifting expects a specific ops pattern:
///  * `before` block consisting of a single `arith.cmpi` op (`slt` or `sgt`)
///    comparing a block argument against a value defined outside the loop
///  * `after` block containing an `arith.addi` that advances that argument by
///    a value defined outside the loop
void populateUpliftWhileToForPatterns(RewritePatternSet &patterns);

} // namespace scf
} // namespace mlir

Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,12 @@ FailureOr<WhileOp> wrapWhileLoopInZeroTripCheck(WhileOp whileOp,
RewriterBase &rewriter,
bool forceCreateCheck = false);

/// Try to uplift `scf.while` op to `scf.for`.
/// Uplifting expects a specific ops pattern:
///  * `before` block consisting of a single `arith.cmpi` op (`slt` or `sgt`)
///    comparing a block argument against a value defined outside the loop
///  * `after` block containing an `arith.addi` that advances that argument by
///    a value defined outside the loop
/// Returns the newly created `scf.for` on success, or failure if `loop` does
/// not have the expected form.
FailureOr<ForOp> upliftWhileToForLoop(RewriterBase &rewriter, WhileOp loop);

} // namespace scf
} // namespace mlir

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
StructuralTypeConversions.cpp
TileUsingInterface.cpp
WrapInZeroTripCheck.cpp
UpliftWhileToFor.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
Expand Down
222 changes: 222 additions & 0 deletions mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms SCF.WhileOp's into SCF.ForOp's.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that we'd better decouple the transformation from the pattern.

That is: expose it as a C++ API FailureOr<scf::ForOp> upliftWhileToForLoop(scf::WhileOp loop); exposed in a public header, and have the client code decide their integration point (this could connect to the Transform dialect for example).

Copy link
Contributor Author

@Hardcode84 Hardcode84 Apr 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Do we still want to make while-to-for uplifting part of while op canonicalization later?

using OpRewritePattern::OpRewritePattern;

/// Thin adapter: runs `upliftWhileToForLoop` on `loop` and reports only
/// whether the rewrite happened; the created `scf.for` itself is not needed.
LogicalResult matchAndRewrite(scf::WhileOp loop,
                              PatternRewriter &rewriter) const override {
  FailureOr<scf::ForOp> uplifted = upliftWhileToForLoop(rewriter, loop);
  if (failed(uplifted))
    return failure();
  return success();
}
};
} // namespace

/// Try to rewrite `loop` (an `scf.while`) into an equivalent `scf.for`.
///
/// The rewrite is only performed when all of the following hold:
///  * the `before` block contains exactly one non-terminator op, an
///    `arith.cmpi`, whose single use is the condition of `scf.condition`;
///  * `scf.condition` forwards the `before` block arguments unchanged and in
///    the same order;
///  * the comparison is `iv slt ub` or `ub sgt iv`, where `iv` is a `before`
///    block argument that has no further uses and `ub` properly dominates the
///    loop;
///  * the `after` block yields, in the position of `iv`, the result of an
///    `arith.addi` of the corresponding `after` block argument and a `step`
///    value that properly dominates the loop.
///
/// All match failures are reported through `rewriter.notifyMatchFailure`
/// before any IR is created. On success the `scf.while` is replaced and the
/// new `scf.for` is returned.
///
/// NOTE(review): the sign of `step` is not checked here; presumably callers
/// rely on it being positive, as `scf.for` expects - confirm.
FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
                                                      scf::WhileOp loop) {
  Block *beforeBody = loop.getBeforeBody();
  if (!llvm::hasSingleElement(beforeBody->without_terminator()))
    return rewriter.notifyMatchFailure(loop, "Loop body must have single op");

  auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
  if (!cmp)
    return rewriter.notifyMatchFailure(loop,
                                       "Loop body must have single cmp op");

  scf::ConditionOp beforeTerm = loop.getConditionOp();
  if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Expected single condition use: " << *cmp;
    });

  // All `before` block args must be directly forwarded to ConditionOp.
  // They will be converted to `scf.for` `iter_vars` except induction var.
  if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
    return rewriter.notifyMatchFailure(loop, "Invalid args order");

  using Pred = arith::CmpIPredicate;
  Pred predicate = cmp.getPredicate();
  if (predicate != Pred::slt && predicate != Pred::sgt)
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
    });

  BlockArgument inductionVar;
  Value ub;
  DominanceInfo dom;

  // Check if cmp has a suitable form. One of the arguments must be a `before`
  // block arg, other must be defined outside `scf.while` and will be treated
  // as upper bound. `iv slt ub` is tried first, then the mirrored `ub sgt iv`.
  for (bool reverse : {false, true}) {
    auto expectedPred = reverse ? Pred::sgt : Pred::slt;
    if (cmp.getPredicate() != expectedPred)
      continue;

    auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
    auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();

    auto blockArg = dyn_cast<BlockArgument>(arg1);
    if (!blockArg || blockArg.getOwner() != beforeBody)
      continue;

    if (!dom.properlyDominates(arg2, loop))
      continue;

    inductionVar = blockArg;
    ub = arg2;
    break;
  }

  if (!inductionVar)
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Unrecognized cmp form: " << *cmp;
    });

  // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
  // arg.
  if (!llvm::hasNItems(inductionVar.getUses(), 2))
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Unrecognized induction var: " << inductionVar;
    });

  Block *afterBody = loop.getAfterBody();
  scf::YieldOp afterTerm = loop.getYieldOp();
  auto argNumber = inductionVar.getArgNumber();
  auto afterTermIndArg = afterTerm.getResults()[argNumber];

  auto inductionVarAfter = afterBody->getArgument(argNumber);

  Value step;

  // Find suitable `addi` op inside `after` block, one of the args must be an
  // Induction var passed from `before` block and second arg must be defined
  // outside of the loop and will be considered step value.
  // TODO: Add `subi` support?
  for (auto &use : inductionVarAfter.getUses()) {
    auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
    if (!owner)
      continue;

    auto other =
        (inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
    if (!dom.properlyDominates(other, loop))
      continue;

    // The increment must be what is yielded back as the next induction value.
    if (afterTermIndArg != owner.getResult())
      continue;

    step = other;
    break;
  }

  if (!step)
    return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");

  auto lb = loop.getInits()[argNumber];

  assert(lb.getType().isIntOrIndex());
  assert(lb.getType() == ub.getType());
  assert(lb.getType() == step.getType());

  llvm::SmallVector<Value> newArgs;

  // Populate inits for new `scf.for`, skip induction var.
  newArgs.reserve(loop.getInits().size());
  for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
    if (i == argNumber)
      continue;

    newArgs.emplace_back(init);
  }

  Location loc = loop.getLoc();

  // With `builder == nullptr`, ForOp::build will try to insert terminator at
  // the end of newly created block and we don't want it. Provide empty
  // dummy builder instead.
  auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
  auto newLoop =
      rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);

  Block *newBody = newLoop.getBody();

  // Populate block args for `scf.for` body, move induction var to the front.
  // `scf.for` places its induction variable first, while the `after` block
  // expects it at position `argNumber`; remap accordingly.
  newArgs.clear();
  ValueRange newBodyArgs = newBody->getArguments();
  for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
    if (i < argNumber) {
      newArgs.emplace_back(newBodyArgs[i + 1]);
    } else if (i == argNumber) {
      newArgs.emplace_back(newBodyArgs.front());
    } else {
      newArgs.emplace_back(newBodyArgs[i]);
    }
  }

  rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
                             newArgs);

  auto term = cast<scf::YieldOp>(newBody->getTerminator());

  // Populate new yield args, skipping the induction var.
  newArgs.clear();
  for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
    if (i == argNumber)
      continue;

    newArgs.emplace_back(arg);
  }

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(term);
  rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);

  // Compute the value the induction variable has when the original
  // `scf.while` exits. The `scf.while` result is the value forwarded by
  // `scf.condition` on the iteration where the comparison fails, i.e. the
  // first value in the sequence lb, lb + step, ... that is not below `ub`:
  //   tripCount = max(ceildiv(ub - lb, step), 0)
  //   result    = lb + tripCount * step
  // The clamp to zero covers loops that never execute (`lb >= ub`), for which
  // the result is `lb` itself.
  rewriter.setInsertionPointAfter(newLoop);
  Type ivType = step.getType();
  auto createConstant = [&](int64_t value) -> Value {
    if (isa<IndexType>(ivType))
      return rewriter.create<arith::ConstantIndexOp>(loc, value);
    return rewriter.create<arith::ConstantIntOp>(loc, value, ivType);
  };
  Value zero = createConstant(0);
  Value one = createConstant(1);

  Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
  Value tripCount = rewriter.create<arith::SubIOp>(loc, ub, lb);
  tripCount = rewriter.create<arith::AddIOp>(loc, tripCount, stepDec);
  tripCount = rewriter.create<arith::DivSIOp>(loc, tripCount, step);
  tripCount = rewriter.create<arith::MaxSIOp>(loc, tripCount, zero);
  Value res = rewriter.create<arith::MulIOp>(loc, tripCount, step);
  res = rewriter.create<arith::AddIOp>(loc, lb, res);

  // Reconstruct `scf.while` results, inserting final induction var value
  // into proper place.
  newArgs.clear();
  llvm::append_range(newArgs, newLoop.getResults());
  newArgs.insert(newArgs.begin() + argNumber, res);
  rewriter.replaceOp(loop, newArgs);
  return newLoop;
}

/// Registers the `scf.while` -> `scf.for` uplifting pattern in `patterns`,
/// using the context the pattern set was created with.
void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<UpliftWhileOp>(context);
}
Loading