[mlir] Add transformation to wrap scf::while in zero-trip-check #81050

Merged
merged 7 commits into from
Feb 8, 2024

Member

@pzread pzread commented Feb 7, 2024

Add scf::wrapWhileLoopInZeroTripCheck to wrap scf while loop in zero-trip-check.

The transformation also rotates the while loop to avoid evaluating the loop condition twice, which might have side-effects.
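
As a rough illustration of the rotation described above, the following plain C++ sketch (not MLIR code; `LoopStats`, `cond`, `runOriginal`, `runWrapped`, and `sameBehavior` are hypothetical names introduced only for this example) compares a `while` loop with its guarded, rotated `do`/`while` form. It assumes a condition whose only side effect is bumping a counter, so the number of evaluations can be observed.

```cpp
#include <cassert>

// Hypothetical bookkeeping used only for this illustration.
struct LoopStats {
  int condEvals = 0; // how many times the side-effecting condition ran
  int bodyRuns = 0;  // how many times the loop body ran
  int result = 0;    // final value of the induction variable
};

// Condition with an observable side effect: it bumps a counter.
inline bool cond(int i, int bound, LoopStats &stats) {
  ++stats.condEvals;
  return i < bound;
}

// Original shape: while (cond) { body }.
inline LoopStats runOriginal(int init, int bound) {
  LoopStats stats;
  int i = init;
  while (cond(i, bound, stats)) {
    ++stats.bodyRuns;
    ++i;
  }
  stats.result = i;
  return stats;
}

// Wrapped and rotated shape: if (cond) { do { body } while (cond); }.
inline LoopStats runWrapped(int init, int bound) {
  LoopStats stats;
  int i = init;
  if (cond(i, bound, stats)) { // zero-trip check, evaluated exactly once
    do {
      ++stats.bodyRuns;
      ++i;
    } while (cond(i, bound, stats));
  }
  stats.result = i;
  return stats;
}

// True when both shapes agree on every quantity tracked in LoopStats.
inline bool sameBehavior(int init, int bound) {
  LoopStats a = runOriginal(init, bound);
  LoopStats b = runWrapped(init, bound);
  return a.condEvals == b.condEvals && a.bodyRuns == b.bodyRuns &&
         a.result == b.result;
}

// Small self-check covering a multi-trip loop and a zero-trip loop.
inline void checkRotationExamples() {
  assert(sameBehavior(0, 3));
  assert(sameBehavior(5, 3));
}
```

Without the rotation, guarding the original `while` with an extra `if (cond)` would evaluate the condition twice before the first iteration, which is the situation the PR description wants to avoid.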

@pzread pzread requested a review from dcaballe February 7, 2024 23:01
@pzread pzread changed the title from "[mlir] Add transformation to wrap scf::while with zero-trip-check" to "[mlir] Add transformation to wrap scf::while in zero-trip-check" Feb 7, 2024
@pzread pzread marked this pull request as ready for review February 7, 2024 23:03
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:scf labels Feb 7, 2024
@llvmbot
Member

llvmbot commented Feb 7, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: Jerry Wu (pzread)

Full diff: https://github.com/llvm/llvm-project/pull/81050.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h (+36)
  • (modified) mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp (+122)
  • (added) mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir (+40)
  • (modified) mlir/test/lib/Dialect/SCF/CMakeLists.txt (+1)
  • (added) mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp (+58)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index e91f9e4469ab7..1c1803113c232 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -30,6 +30,7 @@ namespace scf {
 class IfOp;
 class ForOp;
 class ParallelOp;
+class WhileOp;
 
 /// Fuses all adjacent scf.parallel operations with identical bounds and step
 /// into one scf.parallel operations. Uses a naive aliasing and dependency
@@ -181,6 +182,41 @@ FailureOr<ForOp> pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
                                  const PipeliningOption &options,
                                  bool *modifiedIR = nullptr);
 
+/// Create zero-trip-check around a `while` op and return the new loop op in the
+/// check. The while loop is rotated to avoid evaluating the condition twice. It
+/// turns:
+///
+///   scf.while (%arg0 = %init) : (i32) -> i64 {
+///     %val = .., %arg0 : i64
+///     %cond = arith.cmpi .., %arg0 : i32
+///     scf.condition(%cond) %val : i64
+///   } do {
+///   ^bb0(%arg1: i64):
+///     %next = .., %arg1 : i32
+///     scf.yield %next : i32
+///   }
+///
+///  into:
+///
+///   %pre_val = .., %init : i64
+///   %pre_cond = arith.cmpi .., %init : i32
+///   scf.if %pre_cond -> i64 {
+///     %res = scf.while (%arg1 = %pre_val) : (i64) -> i64 {
+///       %next = .., %arg1 : i32
+///       %val = .., %next : i64
+///       %cond = arith.cmpi .., %next : i32
+///       scf.condition(%cond) %val : i64
+///     } do {
+///     ^bb0(%arg2: i64):
+///       scf.yield %arg2 : i64
+///     }
+///     scf.yield %res : i64
+///   } else {
+///     scf.yield %pre_val : i64
+///   }
+FailureOr<WhileOp> wrapWhileLoopInZeroTripCheck(WhileOp whileOp,
+                                                RewriterBase &rewriter);
+
 } // namespace scf
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index fdaeb2fad9afa..e5494205e086a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ParallelLoopTiling.cpp
   StructuralTypeConversions.cpp
   TileUsingInterface.cpp
+  WrapInZeroTripCheck.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp b/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp
new file mode 100644
index 0000000000000..0e1a15c2bdbda
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp
@@ -0,0 +1,122 @@
+//===- WrapInZeroTripCheck.cpp - Loop transforms to add zero-trip-check ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+/// Create zero-trip-check around a `while` op and return the new loop op in the
+/// check. The while loop is rotated to avoid evaluating the condition twice.
+///
+/// Given an example below:
+///
+///   scf.while (%arg0 = %init) : (i32) -> i64 {
+///     %val = .., %arg0 : i64
+///     %cond = arith.cmpi .., %arg0 : i32
+///     scf.condition(%cond) %val : i64
+///   } do {
+///   ^bb0(%arg1: i64):
+///     %next = .., %arg1 : i32
+///     scf.yield %next : i32
+///   }
+///
+/// First clone before block to the front of the loop:
+///
+///   %pre_val = .., %init : i64
+///   %pre_cond = arith.cmpi .., %init : i32
+///   scf.while (%arg0 = %init) : (i32) -> i64 {
+///     %val = .., %arg0 : i64
+///     %cond = arith.cmpi .., %arg0 : i32
+///     scf.condition(%cond) %val : i64
+///   } do {
+///   ^bb0(%arg1: i64):
+///     %next = .., %arg1 : i32
+///     scf.yield %next : i32
+///   }
+///
+/// Create `if` op with the condition, rotate and move the loop into the then
+/// branch:
+///
+///   %pre_val = .., %init : i64
+///   %pre_cond = arith.cmpi .., %init : i32
+///   scf.if %pre_cond -> i64 {
+///     %res = scf.while (%arg1 = %pre_val) : (i64) -> i64 {
+///       // Original after block
+///       %next = .., %arg1 : i32
+///       // Original before block
+///       %val = .., %next : i64
+///       %cond = arith.cmpi .., %next : i32
+///       scf.condition(%cond) %val : i64
+///     } do {
+///     ^bb0(%arg2: i64):
+///       scf.yield %arg2 : i64
+///     }
+///     scf.yield %res : i64
+///   } else {
+///     scf.yield %pre_val : i64
+///   }
+FailureOr<scf::WhileOp>
+mlir::scf::wrapWhileLoopInZeroTripCheck(scf::WhileOp whileOp,
+                                        RewriterBase &rewriter) {
+  IRMapping mapper;
+  Block *beforeBlock = whileOp.getBeforeBody();
+  // Clone before block before the loop for zero-trip-check.
+  for (auto [arg, init] :
+       llvm::zip_equal(beforeBlock->getArguments(), whileOp.getInits())) {
+    mapper.map(arg, init);
+  }
+  rewriter.setInsertionPoint(whileOp);
+  for (auto &op : *beforeBlock) {
+    if (isa<scf::ConditionOp>(op)) {
+      break;
+    }
+    // Safe to clone everything as in a single block all defs have been cloned
+    // and added to mapper in order.
+    rewriter.insert(op.clone(mapper));
+  }
+
+  auto condOp = whileOp.getConditionOp();
+  auto clonedCondition = mapper.lookupOrDefault(condOp.getCondition());
+  auto clonedCondArgs = llvm::map_to_vector(
+      condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); });
+
+  // Create zero-trip-check and move the while loop in.
+  scf::WhileOp newLoopOp = nullptr;
+  auto ifOp = rewriter.create<scf::IfOp>(
+      whileOp->getLoc(), clonedCondition,
+      [&](OpBuilder &builder, Location loc) {
+        // Then runs the while loop.
+        newLoopOp = builder.create<scf::WhileOp>(
+            loc, whileOp.getResultTypes(), clonedCondArgs,
+            [&](OpBuilder &builder, Location loc, ValueRange args) {
+              // Rotate and move the loop body into before block.
+              auto newBlock = builder.getBlock();
+              rewriter.mergeBlocks(whileOp.getAfterBody(), newBlock, args);
+              auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
+              rewriter.mergeBlocks(whileOp.getBeforeBody(), newBlock,
+                                   yieldOp.getResults());
+              rewriter.eraseOp(yieldOp);
+            },
+            [&](OpBuilder &builder, Location loc, ValueRange args) {
+              // Pass-through values.
+              builder.create<scf::YieldOp>(loc, args);
+            });
+        builder.create<scf::YieldOp>(loc, newLoopOp.getResults());
+      },
+      [&](OpBuilder &builder, Location loc) {
+        // Else returns the results from zero-trip-check.
+        builder.create<scf::YieldOp>(loc, clonedCondArgs);
+      });
+
+  rewriter.replaceOp(whileOp, ifOp);
+
+  return newLoopOp;
+}
diff --git a/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir b/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
new file mode 100644
index 0000000000000..b87c6003ddd31
--- /dev/null
+++ b/mlir/test/Dialect/SCF/wrap-while-loop-in-zero-trip-check.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s -test-wrap-scf-while-loop-in-zero-trip-check -split-input-file  | FileCheck %s
+
+func.func @wrap_while_loop_in_zero_trip_check(%bound : i32) -> i32 {
+  %cst0 = arith.constant 0 : i32
+  %cst5 = arith.constant 5 : i32
+  %res:2 = scf.while (%iter = %cst0) : (i32) -> (i32, i32) {
+    %cond = arith.cmpi slt, %iter, %bound : i32
+    %inv = arith.addi %bound, %cst5 : i32
+    scf.condition(%cond) %iter, %inv : i32, i32
+  } do {
+  ^bb0(%arg1: i32, %arg2: i32):
+    %next = arith.addi %arg1, %arg2 : i32
+    scf.yield %next : i32
+  }
+  return %res#0 : i32
+}
+
+// CHECK-LABEL: func.func @wrap_while_loop_in_zero_trip_check(
+// CHECK-SAME:      %[[ARG0:.*]]: i32) -> i32 {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG:     %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG:     %[[PRE_COND:.*]] = arith.cmpi slt, %[[C0]], %[[ARG0]] : i32
+// CHECK-DAG:     %[[PRE_INV:.*]] = arith.addi %[[ARG0]], %[[C5]] : i32
+// CHECK:         %[[IF:.*]]:2 = scf.if %[[PRE_COND]] -> (i32, i32) {
+// CHECK:           %[[WHILE:.*]]:2 = scf.while (
+// CHECK-SAME:          %[[ARG1:.*]] = %[[C0]], %[[ARG2:.*]] = %[[PRE_INV]]
+// CHECK-SAME:      ) : (i32, i32) -> (i32, i32) {
+// CHECK:             %[[NEXT:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
+// CHECK:             %[[COND:.*]] = arith.cmpi slt, %[[NEXT]], %[[ARG0]] : i32
+// CHECK:             %[[INV:.*]] = arith.addi %[[ARG0]], %[[C5]] : i32
+// CHECK:             scf.condition(%[[COND]]) %[[NEXT]], %[[INV]] : i32, i32
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
+// CHECK:             scf.yield %[[ARG3]], %[[ARG4]] : i32, i32
+// CHECK:           }
+// CHECK:           scf.yield %[[WHILE]]#0, %[[WHILE]]#1 : i32, i32
+// CHECK:         } else {
+// CHECK:           scf.yield %[[C0]], %[[PRE_INV]] : i32, i32
+// CHECK:         }
+// CHECK:         return %[[IF]]#0 : i32
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index 22c2f2388de69..d93bd55915182 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRSCFTestPasses
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
   TestSCFUtils.cpp
+  TestSCFWrapInZeroTripCheck.cpp
   TestWhileOpBuilder.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
new file mode 100644
index 0000000000000..b51ef03288436
--- /dev/null
+++ b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp
@@ -0,0 +1,58 @@
+//===- TestWrapInZeroTripCheck.cpp -- Passes to test SCF zero-trip-check --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the passes to test wrap-in-zero-trip-check transforms on
+// SCF loop ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestWrapWhileLoopInZeroTripCheck
+    : public PassWrapper<TestWrapWhileLoopInZeroTripCheck,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrapWhileLoopInZeroTripCheck)
+
+  StringRef getArgument() const final {
+    return "test-wrap-scf-while-loop-in-zero-trip-check";
+  }
+  StringRef getDescription() const final {
+    return "test scf::wrapWhileLoopInZeroTripCheck";
+  }
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    MLIRContext *context = &getContext();
+    IRRewriter rewriter(context);
+    func.walk([&](scf::WhileOp op) {
+      auto result = scf::wrapWhileLoopInZeroTripCheck(op, rewriter);
+      if (failed(result)) {
+        // Ignore not implemented failure in tests. The expected output should
+        // catch problems (e.g. transformation doesn't happen).
+      }
+    });
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestSCFWrapInZeroTripCheckPasses() {
+  PassRegistration<TestWrapWhileLoopInZeroTripCheck>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 428bdd9691e09..8ca16f17f66e8 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -129,6 +129,7 @@ void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
 void registerTestSCFWhileOpBuilderPass();
+void registerTestSCFWrapInZeroTripCheckPasses();
 void registerTestShapeMappingPass();
 void registerTestSliceAnalysisPass();
 void registerTestTensorCopyInsertionPass();
@@ -251,6 +252,7 @@ void registerTestPasses() {
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSCFWhileOpBuilderPass();
+  mlir::test::registerTestSCFWrapInZeroTripCheckPasses();
   mlir::test::registerTestShapeMappingPass();
   mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestTensorCopyInsertionPass();
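
To make the expected output of the test in wrap-while-loop-in-zero-trip-check.mlir easier to follow, here is a small plain C++ model of both loop shapes. It is only a sketch: `Results`, `originalLoop`, `wrappedLoop`, and `agree` are hypothetical names, and the model uses 64-bit integers instead of the test's `i32`, so it assumes small inputs and does not reproduce wrap-around behavior.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// (result #0, result #1) of the loop, mirroring %res:2 in the test.
using Results = std::pair<std::int64_t, std::int64_t>;

// Models the original scf.while: the "before" region computes the condition
// and the forwarded values, the "after" region computes %next.
inline Results originalLoop(std::int64_t bound) {
  std::int64_t iter = 0; // %iter = %cst0
  while (true) {
    bool cond = iter < bound;     // arith.cmpi slt, %iter, %bound
    std::int64_t inv = bound + 5; // arith.addi %bound, %cst5
    if (!cond)
      return {iter, inv}; // scf.condition forwards these as the results
    iter = iter + inv;    // after region: %next = %arg1 + %arg2
  }
}

// Models the transformed IR matched by the CHECK lines.
inline Results wrappedLoop(std::int64_t bound) {
  bool preCond = 0 < bound;        // cloned condition, uses the init value
  std::int64_t preInv = bound + 5; // cloned loop-invariant value
  if (!preCond)
    return {0, preInv}; // else branch: yield the cloned condition args
  std::int64_t arg1 = 0, arg2 = preInv;
  while (true) {
    std::int64_t next = arg1 + arg2; // original after block
    bool cond = next < bound;        // original before block
    std::int64_t inv = bound + 5;
    if (!cond)
      return {next, inv};
    arg1 = next; // pass-through after region
    arg2 = inv;
  }
}

// True when both models produce the same pair of results.
inline bool agree(std::int64_t bound) {
  return originalLoop(bound) == wrappedLoop(bound);
}

// Small self-check over a few bounds, including zero-trip cases.
inline void checkWrappedModel() {
  for (std::int64_t bound = -4; bound <= 16; ++bound)
    assert(agree(bound));
}
```

For example, with a bound of 3 both models return (8, 8), and with a bound of 0 both take the zero-trip path and return (0, 5).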

Contributor

@dcaballe dcaballe left a comment

I took a first look. LG! A few comments

@pzread pzread requested a review from dcaballe February 8, 2024 18:48
Contributor

@dcaballe dcaballe left a comment

Thanks for addressing the feedback!

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

I don't know the mechanics of the while loop well enough to comment without looking deeper, but thanks for making this its own entry point.

@@ -183,8 +183,12 @@ FailureOr<ForOp> pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
                                  bool *modifiedIR = nullptr);
 
 /// Create zero-trip-check around a `while` op and return the new loop op in the
-/// check. The while loop is rotated to avoid evaluating the condition twice. It
-/// turns:
+/// check. The while loop is rotated to avoid evaluating the condition twice
Contributor

Just to future proof, maybe a good idea to return the if generated as well.

Member Author

IMO users can look at the parent op to get the if when they need it.

And in the do {} while (...) case we actually don't generate the if (and don't rotate the loop), because everything in the before block is always executed (there is a parameter to control this behavior).

I don't have a strong opinion on how this method should be designed, but only returning the loop seems to be enough for now.
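
The do-while point can be sketched with a minimal C++ model of `scf.while` over a single `int`. This is an illustration under stated assumptions, not MLIR API: `runWhile`, `DoWhileResult`, and `doubleUntil` are hypothetical names, and the example assumes a positive starting value so that the loop terminates. When the after block only forwards its argument, all of the work sits in the before block, which runs at least once regardless of the condition.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Minimal model of scf.while on one int value: `before` returns
// {condition, forwarded value}; `after` maps the forwarded value to the
// argument of the next `before` evaluation.
inline int runWhile(const std::function<std::pair<bool, int>(int)> &before,
                    const std::function<int(int)> &after, int init) {
  int arg = init;
  while (true) {
    std::pair<bool, int> r = before(arg);
    if (!r.first)
      return r.second; // loop result is the forwarded value
    arg = after(r.second);
  }
}

struct DoWhileResult {
  int value;    // loop result
  int workRuns; // how many times the real work in `before` executed
};

// do { x *= 2; } while (x < limit); expressed in the model above.
// Assumes init > 0 so the value eventually reaches the limit.
inline DoWhileResult doubleUntil(int init, int limit) {
  int workRuns = 0;
  auto before = [&](int x) {
    ++workRuns; // the loop's actual work lives in the before block
    int next = x * 2;
    return std::make_pair(next < limit, next);
  };
  auto after = [](int v) { return v; }; // pass-through after block
  int value = runWhile(before, after, init);
  return {value, workRuns};
}

// Small self-check: the work runs once even if the condition never holds.
inline void checkDoWhileExamples() {
  assert(doubleUntil(100, 10).workRuns == 1);
  assert(doubleUntil(1, 10).workRuns == 4);
}
```

Because the work already executes on the first evaluation of the before block, a zero-trip check would have nothing to protect in this shape, which matches the reasoning given in the comment above.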

Member Author

I'm going to land this as is first. We can iterate on the interface if needed.

Contributor

Extending it as needed sounds good to me!

@pzread pzread merged commit f720150 into llvm:main Feb 8, 2024