-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][scf] Uplift scf.while
to scf.for
#76108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
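@llvm/pr-subscribers-mlir Author: Ivan Butygin (Hardcode84) Changes: Add uplifting from `scf.while` to `scf.for`. This uplifting expects a very specific ops pattern: the `before` block consists of a single `arith.cmpi` op, the `after` block contains an `arith.addi` that advances the iteration variable, and the iteration variable is of `index` type or a signless integer of the configured bitwidth.
We also have a set of patterns to clean up `scf.while` loops. This is part of upstreaming code from numba-mlir. Original code: https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/PromoteToParallel.cpp Full diff: https://github.com/llvm/llvm-project/pull/76108.diff 5 Files Affected: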
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad86873d..ec28bb0b8b8aa8 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -154,4 +154,20 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
}];
}
+def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
+ let summary = "Uplift scf.while ops to scf.for";
+ let description = [{
+ This pass tries to uplift `scf.while` ops to `scf.for` if they have a
+ compatible form. `scf.while` are left unchanged if uplifting is not
+ possible.
+ }];
+
+ let options = [
+ Option<"indexBitWidth", "index-bitwidth", "unsigned",
+ /*default=*/"64",
+ "Bitwidth of index type.">,
+ ];
+ }
+
+
#endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 5c0d5643c01986..9f3cdd93071ea9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -79,6 +79,10 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
/// loop bounds and loop steps are canonicalized.
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
+/// Populate patterns to uplift `scf.while` ops to `scf.for`.
+void populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
+ unsigned indexBitwidth);
+
} // namespace scf
} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index fdaeb2fad9afa4..7643bab80a1308 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
ParallelLoopTiling.cpp
StructuralTypeConversions.cpp
TileUsingInterface.cpp
+ UpliftWhileToFor.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
new file mode 100644
index 00000000000000..cd16b622504953
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -0,0 +1,250 @@
+//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.WhileOp's into SCF.ForOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFUPLIFTWHILETOFOR
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+/// Returns true when the left-hand operand of `op` has the builtin `index`
+/// type or is a signless integer exactly `indexBitWidth` bits wide.
+static bool checkIndexType(arith::CmpIOp op, unsigned indexBitWidth) {
+  auto operandType = op.getLhs().getType();
+  return isa<mlir::IndexType>(operandType) ||
+         operandType.isSignlessInteger(indexBitWidth);
+}
+
+namespace {
+/// Pattern that uplifts an `scf.while` to an `scf.for`.
+///
+/// The match requires:
+///  * the "before" block to hold exactly one `arith.cmpi` (`slt` or `sgt`)
+///    whose result is used only as the `scf.condition` condition;
+///  * `scf.condition` to forward the "before" block arguments unchanged and
+///    in order;
+///  * the compared block argument to be index-like (see `checkIndexType`) and
+///    the other compare operand to properly dominate the loop;
+///  * the "after" block to yield, for that argument, an `arith.addi` of the
+///    argument and a value that properly dominates the loop (the step).
+struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
+  /// `indexBitWidth_` is the signless integer width accepted as index-like in
+  /// addition to the builtin `index` type.
+  UpliftWhileOp(MLIRContext *context, unsigned indexBitWidth_)
+      : OpRewritePattern<scf::WhileOp>(context), indexBitWidth(indexBitWidth_) {
+  }
+
+  LogicalResult matchAndRewrite(scf::WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    Block *beforeBody = loop.getBeforeBody();
+    if (!llvm::hasSingleElement(beforeBody->without_terminator()))
+      return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
+
+    auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
+    if (!cmp)
+      return rewriter.notifyMatchFailure(loop,
+                                         "Loop body must have single cmp op");
+
+    // The compare must feed the loop condition and nothing else; reject the
+    // loop as soon as either requirement is violated.
+    auto beforeTerm = cast<scf::ConditionOp>(beforeBody->getTerminator());
+    if (!llvm::hasSingleElement(cmp->getUses()) ||
+        beforeTerm.getCondition() != cmp.getResult())
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected single condition use: " << *cmp;
+      });
+
+    // "before" arguments must be forwarded to the "after" block as-is so that
+    // argument N of both blocks denotes the same loop-carried value.
+    if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
+      return rewriter.notifyMatchFailure(loop, "Invalid args order");
+
+    using Pred = arith::CmpIPredicate;
+    auto predicate = cmp.getPredicate();
+    if (predicate != Pred::slt && predicate != Pred::sgt)
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
+      });
+
+    if (!checkIndexType(cmp, indexBitWidth))
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected index-like type: " << *cmp;
+      });
+
+    // Accept `iterVar slt end` and the mirrored `end sgt iterVar`. The
+    // predicate was restricted to these two above, so it alone selects which
+    // operand is the candidate iteration variable.
+    BlockArgument iterVar;
+    Value end;
+    DominanceInfo dom;
+    {
+      const bool reverse = predicate == Pred::sgt;
+      Value candidate = reverse ? cmp.getRhs() : cmp.getLhs();
+      Value bound = reverse ? cmp.getLhs() : cmp.getRhs();
+
+      auto blockArg = dyn_cast<BlockArgument>(candidate);
+      if (blockArg && blockArg.getOwner() == beforeBody &&
+          dom.properlyDominates(bound, loop)) {
+        iterVar = blockArg;
+        end = bound;
+      }
+    }
+
+    if (!iterVar)
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Unrecognized cmp form: " << *cmp;
+      });
+
+    // The iteration variable may only be used by the compare and by the
+    // `scf.condition` forwarding.
+    if (!llvm::hasNItems(iterVar.getUses(), 2))
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Unrecognized iter var: " << iterVar;
+      });
+
+    Block *afterBody = loop.getAfterBody();
+    auto afterTerm = cast<scf::YieldOp>(afterBody->getTerminator());
+    auto argNumber = iterVar.getArgNumber();
+    auto afterTermIterArg = afterTerm.getResults()[argNumber];
+
+    auto iterVarAfter = afterBody->getArgument(argNumber);
+
+    // Find the `arith.addi` that produces the yielded iteration variable from
+    // the current one plus a value that properly dominates the loop.
+    // NOTE(review): nothing here checks that the step is positive — confirm
+    // against the `scf.for` requirements.
+    Value step;
+    for (auto &use : iterVarAfter.getUses()) {
+      auto addOp = dyn_cast<arith::AddIOp>(use.getOwner());
+      if (!addOp)
+        continue;
+
+      auto other =
+          (iterVarAfter == addOp.getLhs() ? addOp.getRhs() : addOp.getLhs());
+      if (!dom.properlyDominates(other, loop))
+        continue;
+
+      if (afterTermIterArg != addOp.getResult())
+        continue;
+
+      step = other;
+      break;
+    }
+
+    if (!step)
+      return rewriter.notifyMatchFailure(loop,
+                                         "Didn't found suitable 'add' op");
+
+    auto begin = loop.getInits()[argNumber];
+
+    // `scf.for` is built with `index` bounds here; cast anything else.
+    auto loc = loop.getLoc();
+    auto indexType = rewriter.getIndexType();
+    auto toIndex = [&](Value val) -> Value {
+      if (val.getType() != indexType)
+        return rewriter.create<arith::IndexCastOp>(loc, indexType, val);
+
+      return val;
+    };
+    begin = toIndex(begin);
+    end = toIndex(end);
+    step = toIndex(step);
+
+    // Inits of the new loop: all original inits except the iteration variable.
+    llvm::SmallVector<Value> mapping;
+    mapping.reserve(loop.getInits().size());
+    for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
+      if (i == argNumber)
+        continue;
+
+      mapping.emplace_back(init);
+    }
+
+    // An explicit no-op body builder is passed so that the new loop body
+    // starts empty; the original "after" block is inlined into it below.
+    auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
+    auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
+                                               emptyBuilder);
+
+    Block *newBody = newLoop.getBody();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToStart(newBody);
+
+    // Map every "after" block argument to a value of the new body. The new
+    // body has the induction variable first, followed by the remaining
+    // loop-carried values in their original relative order.
+    mapping.clear();
+    auto newArgs = newBody->getArguments();
+    for (auto i : llvm::seq<size_t>(0, newArgs.size())) {
+      if (i < argNumber) {
+        mapping.emplace_back(newArgs[i + 1]);
+      } else if (i == argNumber) {
+        // Cast the induction variable back to the original type if needed.
+        Value arg = newArgs.front();
+        if (arg.getType() != iterVar.getType())
+          arg =
+              rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), arg);
+        mapping.emplace_back(arg);
+      } else {
+        mapping.emplace_back(newArgs[i]);
+      }
+    }
+
+    rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
+                               mapping);
+
+    auto term = cast<scf::YieldOp>(newBody->getTerminator());
+
+    // New yield: same operands minus the iteration variable.
+    mapping.clear();
+    for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
+      if (i == argNumber)
+        continue;
+
+      mapping.emplace_back(arg);
+    }
+
+    rewriter.setInsertionPoint(term);
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
+
+    // Materialize a replacement for the loop result that carried the
+    // iteration variable:
+    //   begin + ((end - begin + step - 1) / step - 1) * step
+    // NOTE(review): this is `begin + (tripCount - 1) * step`, i.e. the value
+    // seen by the last executed iteration. Confirm this is the intended
+    // replacement (including for a loop that executes zero times) rather than
+    // the first value for which the comparison fails.
+    rewriter.setInsertionPointAfter(newLoop);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
+    Value len = rewriter.create<arith::SubIOp>(loc, end, begin);
+    len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
+    len = rewriter.create<arith::DivSIOp>(loc, len, step);
+    len = rewriter.create<arith::SubIOp>(loc, len, one);
+    Value res = rewriter.create<arith::MulIOp>(loc, len, step);
+    res = rewriter.create<arith::AddIOp>(loc, begin, res);
+    if (res.getType() != iterVar.getType())
+      res = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), res);
+
+    // Re-insert the computed value at the iteration variable's original
+    // result position so result numbering matches the replaced `scf.while`.
+    mapping.clear();
+    llvm::append_range(mapping, newLoop.getResults());
+    mapping.insert(mapping.begin() + argNumber, res);
+    rewriter.replaceOp(loop, mapping);
+    return success();
+  }
+
+private:
+  unsigned indexBitWidth = 0;
+};
+
+/// Pass wrapper that greedily applies the while-to-for uplifting patterns to
+/// the operation the pass runs on.
+struct SCFUpliftWhileToFor final
+    : impl::SCFUpliftWhileToForBase<SCFUpliftWhileToFor> {
+  using SCFUpliftWhileToForBase::SCFUpliftWhileToForBase;
+
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    MLIRContext *context = root->getContext();
+    RewritePatternSet upliftPatterns(context);
+    mlir::scf::populateUpliftWhileToForPatterns(upliftPatterns,
+                                                this->indexBitWidth);
+
+    // Signal a pass failure when the greedy driver reports failure.
+    auto status = applyPatternsAndFoldGreedily(root, std::move(upliftPatterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Adds the `scf.while` to `scf.for` uplifting pattern to `patterns`,
+/// forwarding `indexBitwidth` to the pattern's constructor.
+void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
+                                                 unsigned indexBitwidth) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<UpliftWhileOp>(context, indexBitwidth);
+}
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
new file mode 100644
index 00000000000000..52a5c0f3cd6347
--- /dev/null
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -0,0 +1,162 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-uplift-while-to-for{index-bitwidth=64}))' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : index
+ scf.condition(%1) %arg3 : index
+ } do {
+ ^bb0(%arg3: index):
+ "test.test1"(%arg3) : (index) -> ()
+ %added = arith.addi %arg3, %arg2 : index
+ "test.test2"(%added) : (index) -> ()
+ scf.yield %added : index
+ }
+ return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+// CHECK: "test.test1"(%[[I]]) : (index) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+// CHECK: return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+ %1 = arith.cmpi sgt, %arg1, %arg3 : index
+ scf.condition(%1) %arg3 : index
+ } do {
+ ^bb0(%arg3: index):
+ "test.test1"(%arg3) : (index) -> ()
+ %added = arith.addi %arg3, %arg2 : index
+ "test.test2"(%added) : (index) -> ()
+ scf.yield %added : index
+ }
+ return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+// CHECK: "test.test1"(%[[I]]) : (index) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+// CHECK: return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : index
+ scf.condition(%1) %arg3 : index
+ } do {
+ ^bb0(%arg3: index):
+ "test.test1"(%arg3) : (index) -> ()
+ %added = arith.addi %arg2, %arg3 : index
+ "test.test2"(%added) : (index) -> ()
+ scf.yield %added : index
+ }
+ return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+// CHECK: "test.test1"(%[[I]]) : (index) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[STEP]], %[[I]] : index
+// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+// CHECK: return %[[R7]] : index
+
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
+ %c1 = arith.constant 1 : i32
+ %c2 = arith.constant 2.0 : f32
+ %0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (i32, index, f32) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : index
+ scf.condition(%1) %arg4, %arg3, %arg5 : i32, index, f32
+ } do {
+ ^bb0(%arg4: i32, %arg3: index, %arg5: f32):
+ %1 = "test.test1"(%arg4) : (i32) -> i32
+ %added = arith.addi %arg3, %arg2 : index
+ %2 = "test.test2"(%arg5) : (f32) -> f32
+ scf.yield %1, %added, %2 : i32, index, f32
+ }
+ return %0#0, %0#2 : i32, f32
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
+// CHECK-SAME: iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
+// CHECK: %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
+// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
+// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
+// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
+ %0 = scf.while (%arg3 = %arg0) : (i64) -> (i64) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : i64
+ scf.condition(%1) %arg3 : i64
+ } do {
+ ^bb0(%arg3: i64):
+ "test.test1"(%arg3) : (i64) -> ()
+ %added = arith.addi %arg3, %arg2 : i64
+ "test.test2"(%added) : (i64) -> ()
+ scf.yield %added : i64
+ }
+ return %0 : i64
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGINI:.*]]: i64, %[[ENDI:.*]]: i64, %[[STEPI:.*]]: i64) -> i64
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[BEGIN:.*]] = arith.index_cast %[[BEGINI]] : i64 to index
+// CHECK: %[[END:.*]] = arith.index_cast %[[ENDI]] : i64 to index
+// CHECK: %[[STEP:.*]] = arith.index_cast %[[STEPI]] : i64 to index
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+// CHECK: %[[II:.*]] = arith.index_cast %[[I]] : index to i64
+// CHECK: "test.test1"(%[[II]]) : (i64) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[II]], %[[STEPI]] : i64
+// CHECK: "test.test2"(%[[INC]]) : (i64) -> ()
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+// CHECK: %[[RES:.*]] = arith.index_cast %[[R7]] : index to i64
+// CHECK: return %[[RES]] : i64
|
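@llvm/pr-subscribers-mlir-scf Author: Ivan Butygin (Hardcode84) Changes: Add uplifting from `scf.while` to `scf.for`. This uplifting expects a very specific ops pattern: the `before` block consists of a single `arith.cmpi` op, the `after` block contains an `arith.addi` that advances the iteration variable, and the iteration variable is of `index` type or a signless integer of the configured bitwidth.
We also have a set of patterns to clean up `scf.while` loops. This is part of upstreaming code from numba-mlir. Original code: https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/PromoteToParallel.cpp Full diff: https://github.com/llvm/llvm-project/pull/76108.diff 5 Files Affected: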
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad86873d..ec28bb0b8b8aa8 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -154,4 +154,20 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
}];
}
+def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
+ let summary = "Uplift scf.while ops to scf.for";
+ let description = [{
+ This pass tries to uplift `scf.while` ops to `scf.for` if they have a
+ compatible form. `scf.while` are left unchanged if uplifting is not
+ possible.
+ }];
+
+ let options = [
+ Option<"indexBitWidth", "index-bitwidth", "unsigned",
+ /*default=*/"64",
+ "Bitwidth of index type.">,
+ ];
+ }
+
+
#endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 5c0d5643c01986..9f3cdd93071ea9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -79,6 +79,10 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
/// loop bounds and loop steps are canonicalized.
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
+/// Populate patterns to uplift `scf.while` ops to `scf.for`.
+void populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
+ unsigned indexBitwidth);
+
} // namespace scf
} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index fdaeb2fad9afa4..7643bab80a1308 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
ParallelLoopTiling.cpp
StructuralTypeConversions.cpp
TileUsingInterface.cpp
+ UpliftWhileToFor.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
new file mode 100644
index 00000000000000..cd16b622504953
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -0,0 +1,250 @@
+//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.WhileOp's into SCF.ForOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFUPLIFTWHILETOFOR
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+/// Returns true when the left-hand operand of `op` has the builtin `index`
+/// type or is a signless integer exactly `indexBitWidth` bits wide.
+static bool checkIndexType(arith::CmpIOp op, unsigned indexBitWidth) {
+  auto operandType = op.getLhs().getType();
+  return isa<mlir::IndexType>(operandType) ||
+         operandType.isSignlessInteger(indexBitWidth);
+}
+
+namespace {
+/// Pattern that uplifts an `scf.while` to an `scf.for`.
+///
+/// The match requires:
+///  * the "before" block to hold exactly one `arith.cmpi` (`slt` or `sgt`)
+///    whose result is used only as the `scf.condition` condition;
+///  * `scf.condition` to forward the "before" block arguments unchanged and
+///    in order;
+///  * the compared block argument to be index-like (see `checkIndexType`) and
+///    the other compare operand to properly dominate the loop;
+///  * the "after" block to yield, for that argument, an `arith.addi` of the
+///    argument and a value that properly dominates the loop (the step).
+struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
+  /// `indexBitWidth_` is the signless integer width accepted as index-like in
+  /// addition to the builtin `index` type.
+  UpliftWhileOp(MLIRContext *context, unsigned indexBitWidth_)
+      : OpRewritePattern<scf::WhileOp>(context), indexBitWidth(indexBitWidth_) {
+  }
+
+  LogicalResult matchAndRewrite(scf::WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    Block *beforeBody = loop.getBeforeBody();
+    if (!llvm::hasSingleElement(beforeBody->without_terminator()))
+      return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
+
+    auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
+    if (!cmp)
+      return rewriter.notifyMatchFailure(loop,
+                                         "Loop body must have single cmp op");
+
+    // The compare must feed the loop condition and nothing else; reject the
+    // loop as soon as either requirement is violated.
+    auto beforeTerm = cast<scf::ConditionOp>(beforeBody->getTerminator());
+    if (!llvm::hasSingleElement(cmp->getUses()) ||
+        beforeTerm.getCondition() != cmp.getResult())
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected single condition use: " << *cmp;
+      });
+
+    // "before" arguments must be forwarded to the "after" block as-is so that
+    // argument N of both blocks denotes the same loop-carried value.
+    if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
+      return rewriter.notifyMatchFailure(loop, "Invalid args order");
+
+    using Pred = arith::CmpIPredicate;
+    auto predicate = cmp.getPredicate();
+    if (predicate != Pred::slt && predicate != Pred::sgt)
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
+      });
+
+    if (!checkIndexType(cmp, indexBitWidth))
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected index-like type: " << *cmp;
+      });
+
+    // Accept `iterVar slt end` and the mirrored `end sgt iterVar`. The
+    // predicate was restricted to these two above, so it alone selects which
+    // operand is the candidate iteration variable.
+    BlockArgument iterVar;
+    Value end;
+    DominanceInfo dom;
+    {
+      const bool reverse = predicate == Pred::sgt;
+      Value candidate = reverse ? cmp.getRhs() : cmp.getLhs();
+      Value bound = reverse ? cmp.getLhs() : cmp.getRhs();
+
+      auto blockArg = dyn_cast<BlockArgument>(candidate);
+      if (blockArg && blockArg.getOwner() == beforeBody &&
+          dom.properlyDominates(bound, loop)) {
+        iterVar = blockArg;
+        end = bound;
+      }
+    }
+
+    if (!iterVar)
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Unrecognized cmp form: " << *cmp;
+      });
+
+    // The iteration variable may only be used by the compare and by the
+    // `scf.condition` forwarding.
+    if (!llvm::hasNItems(iterVar.getUses(), 2))
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Unrecognized iter var: " << iterVar;
+      });
+
+    Block *afterBody = loop.getAfterBody();
+    auto afterTerm = cast<scf::YieldOp>(afterBody->getTerminator());
+    auto argNumber = iterVar.getArgNumber();
+    auto afterTermIterArg = afterTerm.getResults()[argNumber];
+
+    auto iterVarAfter = afterBody->getArgument(argNumber);
+
+    // Find the `arith.addi` that produces the yielded iteration variable from
+    // the current one plus a value that properly dominates the loop.
+    // NOTE(review): nothing here checks that the step is positive — confirm
+    // against the `scf.for` requirements.
+    Value step;
+    for (auto &use : iterVarAfter.getUses()) {
+      auto addOp = dyn_cast<arith::AddIOp>(use.getOwner());
+      if (!addOp)
+        continue;
+
+      auto other =
+          (iterVarAfter == addOp.getLhs() ? addOp.getRhs() : addOp.getLhs());
+      if (!dom.properlyDominates(other, loop))
+        continue;
+
+      if (afterTermIterArg != addOp.getResult())
+        continue;
+
+      step = other;
+      break;
+    }
+
+    if (!step)
+      return rewriter.notifyMatchFailure(loop,
+                                         "Didn't found suitable 'add' op");
+
+    auto begin = loop.getInits()[argNumber];
+
+    // `scf.for` is built with `index` bounds here; cast anything else.
+    auto loc = loop.getLoc();
+    auto indexType = rewriter.getIndexType();
+    auto toIndex = [&](Value val) -> Value {
+      if (val.getType() != indexType)
+        return rewriter.create<arith::IndexCastOp>(loc, indexType, val);
+
+      return val;
+    };
+    begin = toIndex(begin);
+    end = toIndex(end);
+    step = toIndex(step);
+
+    // Inits of the new loop: all original inits except the iteration variable.
+    llvm::SmallVector<Value> mapping;
+    mapping.reserve(loop.getInits().size());
+    for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
+      if (i == argNumber)
+        continue;
+
+      mapping.emplace_back(init);
+    }
+
+    // An explicit no-op body builder is passed so that the new loop body
+    // starts empty; the original "after" block is inlined into it below.
+    auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
+    auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
+                                               emptyBuilder);
+
+    Block *newBody = newLoop.getBody();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToStart(newBody);
+
+    // Map every "after" block argument to a value of the new body. The new
+    // body has the induction variable first, followed by the remaining
+    // loop-carried values in their original relative order.
+    mapping.clear();
+    auto newArgs = newBody->getArguments();
+    for (auto i : llvm::seq<size_t>(0, newArgs.size())) {
+      if (i < argNumber) {
+        mapping.emplace_back(newArgs[i + 1]);
+      } else if (i == argNumber) {
+        // Cast the induction variable back to the original type if needed.
+        Value arg = newArgs.front();
+        if (arg.getType() != iterVar.getType())
+          arg =
+              rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), arg);
+        mapping.emplace_back(arg);
+      } else {
+        mapping.emplace_back(newArgs[i]);
+      }
+    }
+
+    rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
+                               mapping);
+
+    auto term = cast<scf::YieldOp>(newBody->getTerminator());
+
+    // New yield: same operands minus the iteration variable.
+    mapping.clear();
+    for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
+      if (i == argNumber)
+        continue;
+
+      mapping.emplace_back(arg);
+    }
+
+    rewriter.setInsertionPoint(term);
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
+
+    // Materialize a replacement for the loop result that carried the
+    // iteration variable:
+    //   begin + ((end - begin + step - 1) / step - 1) * step
+    // NOTE(review): this is `begin + (tripCount - 1) * step`, i.e. the value
+    // seen by the last executed iteration. Confirm this is the intended
+    // replacement (including for a loop that executes zero times) rather than
+    // the first value for which the comparison fails.
+    rewriter.setInsertionPointAfter(newLoop);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
+    Value len = rewriter.create<arith::SubIOp>(loc, end, begin);
+    len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
+    len = rewriter.create<arith::DivSIOp>(loc, len, step);
+    len = rewriter.create<arith::SubIOp>(loc, len, one);
+    Value res = rewriter.create<arith::MulIOp>(loc, len, step);
+    res = rewriter.create<arith::AddIOp>(loc, begin, res);
+    if (res.getType() != iterVar.getType())
+      res = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), res);
+
+    // Re-insert the computed value at the iteration variable's original
+    // result position so result numbering matches the replaced `scf.while`.
+    mapping.clear();
+    llvm::append_range(mapping, newLoop.getResults());
+    mapping.insert(mapping.begin() + argNumber, res);
+    rewriter.replaceOp(loop, mapping);
+    return success();
+  }
+
+private:
+  unsigned indexBitWidth = 0;
+};
+
+/// Pass wrapper that greedily applies the while-to-for uplifting patterns to
+/// the operation the pass runs on.
+struct SCFUpliftWhileToFor final
+    : impl::SCFUpliftWhileToForBase<SCFUpliftWhileToFor> {
+  using SCFUpliftWhileToForBase::SCFUpliftWhileToForBase;
+
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    MLIRContext *context = root->getContext();
+    RewritePatternSet upliftPatterns(context);
+    mlir::scf::populateUpliftWhileToForPatterns(upliftPatterns,
+                                                this->indexBitWidth);
+
+    // Signal a pass failure when the greedy driver reports failure.
+    auto status = applyPatternsAndFoldGreedily(root, std::move(upliftPatterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Adds the `scf.while` to `scf.for` uplifting pattern to `patterns`,
+/// forwarding `indexBitwidth` to the pattern's constructor.
+void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
+                                                 unsigned indexBitwidth) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<UpliftWhileOp>(context, indexBitwidth);
+}
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
new file mode 100644
index 00000000000000..52a5c0f3cd6347
--- /dev/null
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -0,0 +1,162 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-uplift-while-to-for{index-bitwidth=64}))' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : index
+ scf.condition(%1) %arg3 : index
+ } do {
+ ^bb0(%arg3: index):
+ "test.test1"(%arg3) : (index) -> ()
+ %added = arith.addi %arg3, %arg2 : index
+ "test.test2"(%added) : (index) -> ()
+ scf.yield %added : index
+ }
+ return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+// CHECK: "test.test1"(%[[I]]) : (index) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+// CHECK: return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+ %1 = arith.cmpi sgt, %arg1, %arg3 : index
+ scf.condition(%1) %arg3 : index
+ } do {
+ ^bb0(%arg3: index):
+ "test.test1"(%arg3) : (index) -> ()
+ %added = arith.addi %arg3, %arg2 : index
+ "test.test2"(%added) : (index) -> ()
+ scf.yield %added : index
+ }
+ return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+// CHECK: "test.test1"(%[[I]]) : (index) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+// CHECK: return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : index
+ scf.condition(%1) %arg3 : index
+ } do {
+ ^bb0(%arg3: index):
+ "test.test1"(%arg3) : (index) -> ()
+ %added = arith.addi %arg2, %arg3 : index
+ "test.test2"(%added) : (index) -> ()
+ scf.yield %added : index
+ }
+ return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+// CHECK: "test.test1"(%[[I]]) : (index) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[STEP]], %[[I]] : index
+// CHECK: "test.test2"(%[[INC]]) : (index) -> ()
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+// CHECK: return %[[R7]] : index
+
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
+ %c1 = arith.constant 1 : i32
+ %c2 = arith.constant 2.0 : f32
+ %0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (i32, index, f32) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : index
+ scf.condition(%1) %arg4, %arg3, %arg5 : i32, index, f32
+ } do {
+ ^bb0(%arg4: i32, %arg3: index, %arg5: f32):
+ %1 = "test.test1"(%arg4) : (i32) -> i32
+ %added = arith.addi %arg3, %arg2 : index
+ %2 = "test.test2"(%arg5) : (f32) -> f32
+ scf.yield %1, %added, %2 : i32, index, f32
+ }
+ return %0#0, %0#2 : i32, f32
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
+// CHECK-SAME: iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
+// CHECK: %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
+// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
+// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
+// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
+ %0 = scf.while (%arg3 = %arg0) : (i64) -> (i64) {
+ %1 = arith.cmpi slt, %arg3, %arg1 : i64
+ scf.condition(%1) %arg3 : i64
+ } do {
+ ^bb0(%arg3: i64):
+ "test.test1"(%arg3) : (i64) -> ()
+ %added = arith.addi %arg3, %arg2 : i64
+ "test.test2"(%added) : (i64) -> ()
+ scf.yield %added : i64
+ }
+ return %0 : i64
+}
+
+// CHECK-LABEL: func @uplift_while
+// CHECK-SAME: (%[[BEGINI:.*]]: i64, %[[ENDI:.*]]: i64, %[[STEPI:.*]]: i64) -> i64
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[BEGIN:.*]] = arith.index_cast %[[BEGINI]] : i64 to index
+// CHECK: %[[END:.*]] = arith.index_cast %[[ENDI]] : i64 to index
+// CHECK: %[[STEP:.*]] = arith.index_cast %[[STEPI]] : i64 to index
+// CHECK: scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+// CHECK: %[[II:.*]] = arith.index_cast %[[I]] : index to i64
+// CHECK: "test.test1"(%[[II]]) : (i64) -> ()
+// CHECK: %[[INC:.*]] = arith.addi %[[II]], %[[STEPI]] : i64
+// CHECK: "test.test2"(%[[INC]]) : (i64) -> ()
+// CHECK: %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+// CHECK: %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+// CHECK: %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+// CHECK: %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+// CHECK: %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+// CHECK: %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+// CHECK: %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+// CHECK: %[[RES:.*]] = arith.index_cast %[[R7]] : index to i64
+// CHECK: return %[[RES]] : i64
|
@kiranchandramohan this is part of our uplifting pipeline, I've mentioned in #75314 (comment) |
Block *afterBody = loop.getAfterBody(); | ||
auto afterTerm = cast<scf::YieldOp>(afterBody->getTerminator()); | ||
auto argNumber = iterVar.getArgNumber(); | ||
auto afterTermIterArg = afterTerm.getResults()[argNumber]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you need some kind of additional check here to make sure that the iter_args are forwarded from the "before" block to the "after" block in the same order?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's already enforced by if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
check.
Wouldn't this just be a good canonicalization for scf.while? |
I'm not sure if making this a canonicalization is a good idea. I can imagine a lowering pipeline looking like |
I don't quite see the issue though: that just means that such lowering cannot run a canonicalization in the middle. That seems expected to me though, in particular if someone designs a lowering with "in-dialect" intermediate steps (like scf.for -> scf.while). |
I'm not against making it a part of canonicalization and it will work for our case, but I think there is still a big chance that it can break downstream workflows. So, if we really want to make it part of canonicalization, I suggest splitting it into 2 parts:
So, in case we need to revert it, we can revert only the canonicalizer part |
If `before` block args are directly forwarded to `scf.condition`, make sure they are passed in the same order. This is needed for `scf.while` uplifting llvm#76108
Move non-side-effecting ops from `before` region if all their args are defined outside the loop. This is cleanup needed for `scf.while` -> `scf.for` uplifting llvm#76108 as it expects `before` block consisting of single `cmp` op.
ping |
It's not quite clear to me why we need a PSA for new canonicalization? |
As I said, I still believe there is a high chance of it breaking downstream workflows, so a PSA won't hurt. I will update this PR to make it a test pass. |
35d2f4b
to
f16cc61
Compare
Renamed to test pass |
f16cc61
to
b4ca7e0
Compare
rebased |
b4ca7e0
to
c6fb869
Compare
If `before` block args are directly forwarded to `scf.condition`, make sure they are passed in the same order. This is needed for `scf.while` uplifting llvm#76108
scf::ConditionOp beforeTerm = loop.getConditionOp(); | ||
if (!cmp->hasOneUse() && beforeTerm.getCondition() == cmp.getResult()) | ||
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { | ||
diag << "Expected single condiditon use: " << *cmp; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
diag << "Expected single condiditon use: " << *cmp; | |
diag << "Expected single condition use: " << *cmp; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
using namespace mlir; | ||
|
||
namespace { | ||
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems to me that we'd better decouple the transformation from the pattern.
That is: expose it as a C++ API FailureOr<scf::ForOp> upliftWhileToForLoop(scf::WhileOp loop);
exposed in a public header, and have the client code decide their integration point (this could connect to the Transform dialect for example).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Do we still want to make while-to-for uplifting part of while op canonicalization later?
@@ -154,4 +154,18 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> { | |||
}]; | |||
} | |||
|
|||
def TestSCFUpliftWhileToFor : Pass<"test-scf-uplift-while-to-for"> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test passes should be in the test folder I believe.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
c6fb869
to
9bfaf9e
Compare
Given that this is a piece of the uplift pipeline, I think it would be better to see the whole thing than eagerly install this as a canonicalizer in one step. Such loop transformation optimizations tend to grow into special things, and personally, I have come to dislike optimization pipelines that have action at a distance interactions with specific canonicalization patterns for proper functioning. It just makes them very hard to reason about. I'm +1 on being conservative about making the specific patterns actual canonicalizations until the whole pipeline is in place (unless if obvious/universal, which I'm not sure this one is). |
Also, it will need one more pattern (PR TBD) which probably doesn't qualify as canonicalization - duplicate and move ops from before block - to loop body and after, e.g:
to
|
ping |
Add uplifting from `scf.while` to `scf.for`. This uplifting expects a very specific ops pattern: * `before` body consisting of single `arith.cmp` op * `after` body containing `arith.addi` * Iter var must be of type `index` or integer of specified width We also have a set of patterns to cleanup `scf.while` loops to get them close to the desired form, they will be added in separate PRs. This is part of upstreaming `numba-mlir` scf uplifting pipeline: `cf -> scf.while -> scf.for -> scf.parallel` Original code: https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/PromoteToParallel.cpp
9bfaf9e
to
e062c93
Compare
Add uplifting from `scf.while` to `scf.for`. This uplifting expects a very specific ops pattern: * `before` block consisting of single `arith.cmp` op * `after` block containing `arith.addi` We also have a set of patterns to cleanup `scf.while` loops to get them close to the desired form, they will be added in separate PRs. This is part of upstreaming `numba-mlir` scf uplifting pipeline: `cf -> scf.while -> scf.for -> scf.parallel` Original code: https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/PromoteToParallel.cpp
Add uplifting from `scf.while` to `scf.for`.
This uplifting expects a very specific ops pattern:
* `before` block consisting of single `arith.cmp` op
* `after` block containing `arith.addi`
We also have a set of patterns to cleanup `scf.while` loops to get them close to the desired form, they will be added in separate PRs.
This is part of upstreaming `numba-mlir` scf uplifting pipeline: `cf -> scf.while -> scf.for -> scf.parallel`
Original code: https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/PromoteToParallel.cpp