-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Add transformation to wrap scf::while in zero-trip-check #81050
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-scf Author: Jerry Wu (pzread) ChangesAdd The transformation also rotates the while loop to avoid evaluating the loop condition twice, which might have side-effects. Full diff: https://github.com/llvm/llvm-project/pull/81050.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index e91f9e4469ab7..1c1803113c232 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -30,6 +30,7 @@ namespace scf {
class IfOp;
class ForOp;
class ParallelOp;
+class WhileOp;
/// Fuses all adjacent scf.parallel operations with identical bounds and step
/// into one scf.parallel operations. Uses a naive aliasing and dependency
@@ -181,6 +182,41 @@ FailureOr<ForOp> pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
const PipeliningOption &options,
bool *modifiedIR = nullptr);
+/// Create zero-trip-check around a `while` op and return the new loop op in the
+/// check. The while loop is rotated to avoid evaluating the condition twice. It
+/// turns:
+///
+/// scf.while (%arg0 = %init) : (i32) -> i64 {
+/// %val = .., %arg0 : i64
+/// %cond = arith.cmpi .., %arg0 : i32
+/// scf.condition(%cond) %val : i64
+/// } do {
+/// ^bb0(%arg1: i64):
+/// %next = .., %arg1 : i32
+/// scf.yield %next : i32
+/// }
+///
+/// into:
+///
+/// %pre_val = .., %init : i64
+/// %pre_cond = arith.cmpi .., %init : i32
+/// scf.if %pre_cond -> i64 {
+/// %res = scf.while (%arg1 = %va0) : (i64) -> i64 {
+/// %next = .., %arg1 : i32
+/// %val = .., %next : i64
+/// %cond = arith.cmpi .., %next : i32
+/// scf.condition(%cond) %val : i64
+/// } do {
+/// ^bb0(%arg2: i64):
+/// %scf.yield %arg2 : i32
+/// }
+/// scf.yield %res : i64
+/// } else {
+/// scf.yield %pre_val : i64
+/// }
+FailureOr<WhileOp> wrapWhileLoopInZeroTripCheck(WhileOp whileOp,
+ RewriterBase &rewriter);
+
} // namespace scf
} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index fdaeb2fad9afa..e5494205e086a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
ParallelLoopTiling.cpp
StructuralTypeConversions.cpp
TileUsingInterface.cpp
+ WrapInZeroTripCheck.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp b/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp
new file mode 100644
index 0000000000000..0e1a15c2bdbda
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp
@@ -0,0 +1,122 @@
+//===- WrapInZeroTripCheck.cpp - Loop transforms to add zero-trip-check ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+/// Create zero-trip-check around a `while` op and return the new loop op in the
+/// check. The while loop is rotated to avoid evaluating the condition twice.
+///
+/// Given an example below:
+///
+/// scf.while (%arg0 = %init) : (i32) -> i64 {
+/// %val = .., %arg0 : i64
+/// %cond = arith.cmpi .., %arg0 : i32
+/// scf.condition(%cond) %val : i64
+/// } do {
+/// ^bb0(%arg1: i64):
+/// %next = .., %arg1 : i32
+/// scf.yield %next : i32
+/// }
+///
+/// First clone before block to the front of the loop:
+///
+/// %pre_val = .., %init : i64
+/// %pre_cond = arith.cmpi .., %init : i32
+/// scf.while (%arg0 = %init) : (i32) -> i64 {
+/// %val = .., %arg0 : i64
+/// %cond = arith.cmpi .., %arg0 : i32
+/// scf.condition(%cond) %val : i64
+/// } do {
+/// ^bb0(%arg1: i64):
+/// %next = .., %arg1 : i32
+/// scf.yield %next : i32
+/// }
+///
+/// Create `if` op with the condition, rotate and move the loop into the else
+/// branch:
+///
+/// %pre_val = .., %init : i64
+/// %pre_cond = arith.cmpi .., %init : i32
+/// scf.if %pre_cond -> i64 {
+/// %res = scf.while (%arg1 = %va0) : (i64) -> i64 {
+/// // Original after block
+/// %next = .., %arg1 : i32
+/// // Original before block
+/// %val = .., %next : i64
+/// %cond = arith.cmpi .., %next : i32
+/// scf.condition(%cond) %val : i64
+/// } do {
+/// ^bb0(%arg2: i64):
+/// %scf.yield %arg2 : i32
+/// }
+/// scf.yield %res : i64
+/// } else {
+/// scf.yield %pre_val : i64
+/// }
+FailureOr<scf::WhileOp>
+mlir::scf::wrapWhileLoopInZeroTripCheck(scf::WhileOp whileOp,
+ RewriterBase &rewriter) {
+ IRMapping mapper;
+ Block *beforeBlock = whileOp.getBeforeBody();
+ // Clone before block before the loop for zero-trip-check.
+ for (auto [arg, init] :
+ llvm::zip_equal(beforeBlock->getArguments(), whileOp.getInits())) {
+ mapper.map(arg, init);
+ }
+ rewriter.setInsertionPoint(whileOp);
+ for (auto &op : *beforeBlock) {
+ if (isa<scf::ConditionOp>(op)) {
+ break;
+ }
+ // Safe to clone everything as in a single block all defs have been cloned
+ // and added to mapper in order.
+ rewriter.insert(op.clone(mapper));
+ }
+
+ auto condOp = whileOp.getConditionOp();
+ auto clonedCondition = mapper.lookupOrDefault(condOp.getCondition());
+ auto clonedCondArgs = llvm::map_to_vector(
+ condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); });
+
+ // Create zero-trip-check and move the while loop in.
+ scf::WhileOp newLoopOp = nullptr;
+ auto ifOp = rewriter.create<scf::IfOp>(
+ whileOp->getLoc(), clonedCondition,
+ [&](OpBuilder &builder, Location loc) {
+ // Then runs the while loop.
+ newLoopOp = builder.create<scf::WhileOp>(
+ loc, whileOp.getResultTypes(), clonedCondArgs,
+ [&](OpBuilder &builder, Location loc, ValueRange args) {
+ // Rotate and move the loop body into before block.
+ auto newBlock = builder.getBlock();
+ rewriter.mergeBlocks(whileOp.getAfterBody(), newBlock, args);
+ auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
+ rewriter.mergeBlocks(whileOp.getBeforeBody(), newBlock,
+ yieldOp.getResults());
+ rewriter.eraseOp(yieldOp);
+ },
+ [&](OpBuilder &builder, Location loc, ValueRange args) {
+ // Pass-through values.
+ builder.create<scf::YieldOp>(loc, args);
+ });
+ builder.create<scf::YieldOp>(loc, newLoopOp.getResults());
+ },
+ [&](OpBuilder &builder, Location loc) {
+ // Else returns the results from zero-trip-check.
+ builder.create<scf::YieldOp>(loc, clonedCondArgs);
+ });
+
+ rewriter.replaceOp(whileOp, ifOp);
+
+ return newLoopOp;
+}
diff --git a/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir b/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
new file mode 100644
index 0000000000000..b87c6003ddd31
--- /dev/null
+++ b/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s -test-wrap-scf-while-loop-in-zero-trip-check -split-input-file | FileCheck %s
+
+func.func @wrap_while_loop_in_zero_trip_check(%bound : i32) -> i32 {
+ %cst0 = arith.constant 0 : i32
+ %cst5 = arith.constant 5 : i32
+ %res:2 = scf.while (%iter = %cst0) : (i32) -> (i32, i32) {
+ %cond = arith.cmpi slt, %iter, %bound : i32
+ %inv = arith.addi %bound, %cst5 : i32
+ scf.condition(%cond) %iter, %inv : i32, i32
+ } do {
+ ^bb0(%arg1: i32, %arg2: i32):
+ %next = arith.addi %arg1, %arg2 : i32
+ scf.yield %next : i32
+ }
+ return %res#0 : i32
+}
+
+// CHECK-LABEL: func.func @wrap_while_loop_in_zero_trip_check(
+// CHECK-SAME: %[[ARG0:.*]]: i32) -> i32 {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG: %[[PRE_COND:.*]] = arith.cmpi slt, %[[C0]], %[[ARG0]] : i32
+// CHECK-DAG: %[[PRE_INV:.*]] = arith.addi %[[ARG0]], %[[C5]] : i32
+// CHECK: %[[IF:.*]]:2 = scf.if %[[PRE_COND]] -> (i32, i32) {
+// CHECK: %[[WHILE:.*]]:2 = scf.while (
+// CHECK-SAME: %[[ARG1:.*]] = %[[C0]], %[[ARG2:.*]] = %[[PRE_INV]]
+// CHECK-SAME: ) : (i32, i32) -> (i32, i32) {
+// CHECK: %[[NEXT:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
+// CHECK: %[[COND:.*]] = arith.cmpi slt, %[[NEXT]], %[[ARG0]] : i32
+// CHECK: %[[INV:.*]] = arith.addi %[[ARG0]], %[[C5]] : i32
+// CHECK: scf.condition(%[[COND]]) %[[NEXT]], %[[INV]] : i32, i32
+// CHECK: } do {
+// CHECK: ^bb0(%[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
+// CHECK: scf.yield %[[ARG3]], %[[ARG4]] : i32, i32
+// CHECK: }
+// CHECK: scf.yield %[[WHILE]]#0, %[[WHILE]]#1 : i32, i32
+// CHECK: } else {
+// CHECK: scf.yield %[[C0]], %[[PRE_INV]] : i32, i32
+// CHECK: }
+// CHECK: return %[[IF]]#0 : i32
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index 22c2f2388de69..d93bd55915182 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRSCFTestPasses
TestLoopParametricTiling.cpp
TestLoopUnrolling.cpp
TestSCFUtils.cpp
+ TestSCFWrapInZeroTripCheck.cpp
TestWhileOpBuilder.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
new file mode 100644
index 0000000000000..b51ef03288436
--- /dev/null
+++ b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
@@ -0,0 +1,58 @@
+//===- TestWrapInZeroTripCheck.cpp -- Passes to test SCF zero-trip-check --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the passes to test wrap-in-zero-trip-check transforms on
+// SCF loop ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestWrapWhileLoopInZeroTripCheck
+ : public PassWrapper<TestWrapWhileLoopInZeroTripCheck,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrapWhileLoopInZeroTripCheck)
+
+ StringRef getArgument() const final {
+ return "test-wrap-scf-while-loop-in-zero-trip-check";
+ }
+ StringRef getDescription() const final {
+ return "test scf::wrapWhileLoopInZeroTripCheck";
+ }
+
+ void runOnOperation() override {
+ func::FuncOp func = getOperation();
+ MLIRContext *context = &getContext();
+ IRRewriter rewriter(context);
+ func.walk([&](scf::WhileOp op) {
+ auto result = scf::wrapWhileLoopInZeroTripCheck(op, rewriter);
+ if (failed(result)) {
+ // Ignore not implemented failure in tests. The expected output should
+ // catch problems (e.g. transformation doesn't happen).
+ }
+ });
+ }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestSCFWrapInZeroTripCheckPasses() {
+ PassRegistration<TestWrapWhileLoopInZeroTripCheck>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 428bdd9691e09..8ca16f17f66e8 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -129,6 +129,7 @@ void registerTestPreparationPassWithAllowedMemrefResults();
void registerTestRecursiveTypesPass();
void registerTestSCFUtilsPass();
void registerTestSCFWhileOpBuilderPass();
+void registerTestSCFWrapInZeroTripCheckPasses();
void registerTestShapeMappingPass();
void registerTestSliceAnalysisPass();
void registerTestTensorCopyInsertionPass();
@@ -251,6 +252,7 @@ void registerTestPasses() {
mlir::test::registerTestRecursiveTypesPass();
mlir::test::registerTestSCFUtilsPass();
mlir::test::registerTestSCFWhileOpBuilderPass();
+ mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
mlir::test::registerTestShapeMappingPass();
mlir::test::registerTestSliceAnalysisPass();
mlir::test::registerTestTensorCopyInsertionPass();
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I took a first look. LG! A few comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing the feedback!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I dont know the mechanics of while loop to comment without looking deeper, but thanks for making its own entry point.
@@ -183,8 +183,12 @@ FailureOr<ForOp> pipelineForLoop(RewriterBase &rewriter, ForOp forOp, | |||
bool *modifiedIR = nullptr); | |||
|
|||
/// Create zero-trip-check around a `while` op and return the new loop op in the | |||
/// check. The while loop is rotated to avoid evaluating the condition twice. It | |||
/// turns: | |||
/// check. The while loop is rotated to avoid evaluating the condition twice |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to future proof, maybe a good idea to return the if
generated as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO users can look into the parent op for the if
they need.
And in do {} while (...)
case we actually don't generate the if
(and don't rotate the0 loop) because everything in the before block is always executed (there is a parameter to control this behavior)
I don't have a strong opinion on how this method should be designed, but only returning the loop seems to be enough for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm going to land this as it is first. We can re-iterate on interface if needed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Extending it as needed sounds good to me!
Add
scf::wrapWhileLoopInZeroTripCheck
to wrap scf while loop in zero-trip-check.The transformation also rotates the while loop to avoid evaluating the loop condition twice, which might have side-effects.