
[mlir][scf] Uplift scf.while to scf.for #76108


Merged 10 commits on Apr 15, 2024


@Hardcode84 Hardcode84 commented Dec 20, 2023

Add uplifting from scf.while to scf.for.

This uplifting expects a very specific op pattern:

  • the before block consists of a single arith.cmpi op
  • the after block contains an arith.addi that advances the iteration variable

We also have a set of patterns that clean up scf.while loops to bring them close to the desired form; they will be added in separate PRs.
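
In source-level terms, the shape described above corresponds to a while loop whose condition is one signed compare and whose body advances the iteration variable by a loop-invariant amount. The following is a minimal plain C++ sketch of that correspondence, not MLIR code; the function names are hypothetical and a positive step is assumed.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: the loop shape the pattern looks for, written as a
// plain while loop. Assumes step > 0.
std::vector<int64_t> visitWithWhile(int64_t begin, int64_t end, int64_t step) {
  std::vector<int64_t> visited;
  int64_t i = begin;
  while (i < end) {        // "before" block: a single signed compare
    visited.push_back(i);  // "after" block: arbitrary body...
    i += step;             // ...plus an add that feeds the next iteration
  }
  return visited;
}

// The same iteration space as a counted loop, the form that scf.for
// represents directly.
std::vector<int64_t> visitWithFor(int64_t begin, int64_t end, int64_t step) {
  std::vector<int64_t> visited;
  for (int64_t i = begin; i < end; i += step)
    visited.push_back(i);
  return visited;
}
```

Both helpers visit the same values for any `begin`, `end`, and positive `step`, which is the property the uplifting relies on.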

This is part of upstreaming the numba-mlir scf uplifting pipeline: cf -> scf.while -> scf.for -> scf.parallel

Original code: https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/PromoteToParallel.cpp


llvmbot commented Dec 20, 2023

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Add uplifting from scf.while to scf.for.

This uplifting expects a very specific op pattern:

  • the before block consists of a single arith.cmpi op
  • the after block contains an arith.addi that advances the iteration variable
  • the iteration variable must be of type index or a signless integer of the specified width
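
The compare in the before block is accepted in two spellings in the diff: `slt` with the iteration variable on the left, or `sgt` with it on the right. The sketch below models that decision in plain C++ rather than with MLIR types; the enums and the `findUpperBound` name are hypothetical stand-ins used only for illustration.

```cpp
#include <cassert>

// Hypothetical stand-ins for the two accepted arith.cmpi predicates.
enum class Pred { slt, sgt };

// Which compare operand a value sits on; `none` means "not recognized".
enum class Side { none, lhs, rhs };

// Returns the operand that holds the loop-invariant upper bound, or
// Side::none when the iteration variable is not on the expected side.
//   slt: iterVar < end      -> bound is the rhs
//   sgt: end     > iterVar  -> bound is the lhs
Side findUpperBound(Pred pred, Side iterVarSide) {
  if (pred == Pred::slt && iterVarSide == Side::lhs)
    return Side::rhs;
  if (pred == Pred::sgt && iterVarSide == Side::rhs)
    return Side::lhs;
  return Side::none;
}
```

Any other combination, such as `slt` with the iteration variable on the right, falls through to `Side::none`, which corresponds to the pattern reporting a match failure.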

We also have a set of patterns that clean up scf.while loops to bring them close to the desired form; they will be added in separate PRs.

This is part of upstreaming numba-mlir scf uplifting pipeline: cf -> scf.while -> scf.for -> scf.parallel

Original code: https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/PromoteToParallel.cpp


Full diff: https://github.com/llvm/llvm-project/pull/76108.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.td (+16)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h (+4)
  • (modified) mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp (+250)
  • (added) mlir/test/Dialect/SCF/uplift-while.mlir (+162)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad86873d..ec28bb0b8b8aa8 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -154,4 +154,20 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   }];
 }
 
+def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
+  let summary = "Uplift scf.while ops to scf.for";
+  let description = [{
+    This pass tries to uplift `scf.while` ops to `scf.for` if they have a
+    compatible form. `scf.while` are left unchanged if uplifting is not
+    possible.
+  }];
+
+  let options = [
+    Option<"indexBitWidth", "index-bitwidth", "unsigned",
+           /*default=*/"64",
+           "Bitwidth of index type.">,
+  ];
+}
+
+
 #endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 5c0d5643c01986..9f3cdd93071ea9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -79,6 +79,10 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
 /// loop bounds and loop steps are canonicalized.
 void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
 
+/// Populate patterns to uplift `scf.while` ops to `scf.for`.
+void populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
+                                      unsigned indexBitwidth);
+
 } // namespace scf
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index fdaeb2fad9afa4..7643bab80a1308 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ParallelLoopTiling.cpp
   StructuralTypeConversions.cpp
   TileUsingInterface.cpp
+  UpliftWhileToFor.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
new file mode 100644
index 00000000000000..cd16b622504953
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -0,0 +1,250 @@
+//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.WhileOp's into SCF.ForOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFUPLIFTWHILETOFOR
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+static bool checkIndexType(arith::CmpIOp op, unsigned indexBitWidth) {
+  auto type = op.getLhs().getType();
+  if (isa<mlir::IndexType>(type))
+    return true;
+
+  if (type.isSignlessInteger(indexBitWidth))
+    return true;
+
+  return false;
+}
+
+namespace {
+struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
+  UpliftWhileOp(MLIRContext *context, unsigned indexBitWidth_)
+      : OpRewritePattern<scf::WhileOp>(context), indexBitWidth(indexBitWidth_) {
+  }
+
+  LogicalResult matchAndRewrite(scf::WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    Block *beforeBody = loop.getBeforeBody();
+    if (!llvm::hasSingleElement(beforeBody->without_terminator()))
+      return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
+
+    auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
+    if (!cmp)
+      return rewriter.notifyMatchFailure(loop,
+                                         "Loop body must have single cmp op");
+
+    auto beforeTerm = cast<scf::ConditionOp>(beforeBody->getTerminator());
+    if (!llvm::hasSingleElement(cmp->getUses()) ||
+        beforeTerm.getCondition() != cmp.getResult())
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected single condition use: " << *cmp;
+      });
+
+    if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
+      return rewriter.notifyMatchFailure(loop, "Invalid args order");
+
+    using Pred = arith::CmpIPredicate;
+    auto predicate = cmp.getPredicate();
+    if (predicate != Pred::slt && predicate != Pred::sgt)
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
+      });
+
+    if (!checkIndexType(cmp, indexBitWidth))
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected index-like type: " << *cmp;
+      });
+
+    BlockArgument iterVar;
+    Value end;
+    DominanceInfo dom;
+    for (bool reverse : {false, true}) {
+      auto expectedPred = reverse ? Pred::sgt : Pred::slt;
+      if (cmp.getPredicate() != expectedPred)
+        continue;
+
+      auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
+      auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
+
+      auto blockArg = dyn_cast<BlockArgument>(arg1);
+      if (!blockArg || blockArg.getOwner() != beforeBody)
+        continue;
+
+      if (!dom.properlyDominates(arg2, loop))
+        continue;
+
+      iterVar = blockArg;
+      end = arg2;
+      break;
+    }
+
+    if (!iterVar)
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Unrecognized cmp form: " << *cmp;
+      });
+
+    if (!llvm::hasNItems(iterVar.getUses(), 2))
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Unrecognized iter var: " << iterVar;
+      });
+
+    Block *afterBody = loop.getAfterBody();
+    auto afterTerm = cast<scf::YieldOp>(afterBody->getTerminator());
+    auto argNumber = iterVar.getArgNumber();
+    auto afterTermIterArg = afterTerm.getResults()[argNumber];
+
+    auto iterVarAfter = afterBody->getArgument(argNumber);
+
+    Value step;
+    for (auto &use : iterVarAfter.getUses()) {
+      auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
+      if (!owner)
+        continue;
+
+      auto other =
+          (iterVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
+      if (!dom.properlyDominates(other, loop))
+        continue;
+
+      if (afterTermIterArg != owner.getResult())
+        continue;
+
+      step = other;
+      break;
+    }
+
+    if (!step)
+      return rewriter.notifyMatchFailure(loop,
+                                         "Didn't find suitable 'add' op");
+
+    auto begin = loop.getInits()[argNumber];
+
+    auto loc = loop.getLoc();
+    auto indexType = rewriter.getIndexType();
+    auto toIndex = [&](Value val) -> Value {
+      if (val.getType() != indexType)
+        return rewriter.create<arith::IndexCastOp>(loc, indexType, val);
+
+      return val;
+    };
+    begin = toIndex(begin);
+    end = toIndex(end);
+    step = toIndex(step);
+
+    llvm::SmallVector<Value> mapping;
+    mapping.reserve(loop.getInits().size());
+    for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
+      if (i == argNumber)
+        continue;
+
+      mapping.emplace_back(init);
+    }
+
+    auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
+    auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
+                                               emptyBuilder);
+
+    Block *newBody = newLoop.getBody();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToStart(newBody);
+    Value newIterVar = newBody->getArgument(0);
+    if (newIterVar.getType() != iterVar.getType())
+      newIterVar = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(),
+                                                       newIterVar);
+
+    mapping.clear();
+    auto newArgs = newBody->getArguments();
+    for (auto i : llvm::seq<size_t>(0, newArgs.size())) {
+      if (i < argNumber) {
+        mapping.emplace_back(newArgs[i + 1]);
+      } else if (i == argNumber) {
+        Value arg = newArgs.front();
+        if (arg.getType() != iterVar.getType())
+          arg =
+              rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), arg);
+        mapping.emplace_back(arg);
+      } else {
+        mapping.emplace_back(newArgs[i]);
+      }
+    }
+
+    rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
+                               mapping);
+
+    auto term = cast<scf::YieldOp>(newBody->getTerminator());
+
+    mapping.clear();
+    for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
+      if (i == argNumber)
+        continue;
+
+      mapping.emplace_back(arg);
+    }
+
+    rewriter.setInsertionPoint(term);
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
+
+    rewriter.setInsertionPointAfter(newLoop);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
+    Value len = rewriter.create<arith::SubIOp>(loc, end, begin);
+    len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
+    len = rewriter.create<arith::DivSIOp>(loc, len, step);
+    len = rewriter.create<arith::SubIOp>(loc, len, one);
+    Value res = rewriter.create<arith::MulIOp>(loc, len, step);
+    res = rewriter.create<arith::AddIOp>(loc, begin, res);
+    if (res.getType() != iterVar.getType())
+      res = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), res);
+
+    mapping.clear();
+    llvm::append_range(mapping, newLoop.getResults());
+    mapping.insert(mapping.begin() + argNumber, res);
+    rewriter.replaceOp(loop, mapping);
+    return success();
+  }
+
+private:
+  unsigned indexBitWidth = 0;
+};
+
+struct SCFUpliftWhileToFor final
+    : impl::SCFUpliftWhileToForBase<SCFUpliftWhileToFor> {
+  using SCFUpliftWhileToForBase::SCFUpliftWhileToForBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    RewritePatternSet patterns(ctx);
+    mlir::scf::populateUpliftWhileToForPatterns(patterns, this->indexBitWidth);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
+                                                 unsigned indexBitwidth) {
+  patterns.add<UpliftWhileOp>(patterns.getContext(), indexBitwidth);
+}
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
new file mode 100644
index 00000000000000..52a5c0f3cd6347
--- /dev/null
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -0,0 +1,162 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-uplift-while-to-for{index-bitwidth=64}))' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg3, %arg2 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi sgt, %arg1, %arg3 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg3, %arg2 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg2, %arg3 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[STEP]], %[[I]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index
+
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
+  %c1 = arith.constant 1 : i32
+  %c2 = arith.constant 2.0 : f32
+  %0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (i32, index, f32) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg4, %arg3, %arg5 : i32, index, f32
+  } do {
+  ^bb0(%arg4: i32, %arg3: index, %arg5: f32):
+    %1 = "test.test1"(%arg4) : (i32) -> i32
+    %added = arith.addi %arg3, %arg2 : index
+    %2 = "test.test2"(%arg5) : (f32) -> f32
+    scf.yield %1, %added, %2 : i32, index, f32
+  }
+  return %0#0, %0#2 : i32, f32
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
+//   CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : i32
+//   CHECK-DAG:     %[[C2:.*]] = arith.constant 2.000000e+00 : f32
+//       CHECK:     %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
+//  CHECK-SAME:     iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
+//       CHECK:     %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
+//       CHECK:     %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
+//       CHECK:     scf.yield %[[T1]], %[[T2]] : i32, f32
+//       CHECK:     return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
+  %0 = scf.while (%arg3 = %arg0) : (i64) -> (i64) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : i64
+    scf.condition(%1) %arg3 : i64
+  } do {
+  ^bb0(%arg3: i64):
+    "test.test1"(%arg3) : (i64) -> ()
+    %added = arith.addi %arg3, %arg2 : i64
+    "test.test2"(%added) : (i64) -> ()
+    scf.yield %added : i64
+  }
+  return %0 : i64
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGINI:.*]]: i64, %[[ENDI:.*]]: i64, %[[STEPI:.*]]: i64) -> i64
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     %[[BEGIN:.*]] = arith.index_cast %[[BEGINI]] : i64 to index
+//       CHECK:     %[[END:.*]] = arith.index_cast %[[ENDI]] : i64 to index
+//       CHECK:     %[[STEP:.*]] = arith.index_cast %[[STEPI]] : i64 to index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     %[[II:.*]] = arith.index_cast %[[I]] : index to i64
+//       CHECK:     "test.test1"(%[[II]]) : (i64) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[II]], %[[STEPI]] : i64
+//       CHECK:     "test.test2"(%[[INC]]) : (i64) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     %[[RES:.*]] = arith.index_cast %[[R7]] : index to i64
+//       CHECK:     return %[[RES]] : i64
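
The value that replaces the iteration variable's result after the new loop is built from a short arithmetic chain (R1 through R7 in the CHECK lines above). The following plain C++ transcription of that chain is a sketch only: `replacementValue` is a hypothetical name, and it assumes `step > 0` and `begin < end`, which the pattern itself does not check.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the arith op sequence emitted after the new scf.for:
//   stepDec = step - 1
//   len     = (end - begin + stepDec) / step   (signed division)
//   res     = begin + (len - 1) * step
// Assumes step > 0 and begin < end; other inputs are outside this sketch.
int64_t replacementValue(int64_t begin, int64_t end, int64_t step) {
  int64_t stepDec = step - 1;
  int64_t len = (end - begin + stepDec) / step;
  return begin + (len - 1) * step;
}
```

Under these assumptions, `begin = 0`, `end = 10`, `step = 3` gives `len = 4` and a result of 9, which is the induction value of the last executed iteration.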


llvmbot commented Dec 20, 2023

@llvm/pr-subscribers-mlir-scf

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad86873d..ec28bb0b8b8aa8 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -154,4 +154,20 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   }];
 }
 
+def SCFUpliftWhileToFor : Pass<"scf-uplift-while-to-for"> {
+  let summary = "Uplift scf.while ops to scf.for";
+  let description = [{
+    This pass tries to uplift `scf.while` ops to `scf.for` if they have a
+    compatible form. `scf.while` are left unchanged if uplifting is not
+    possible.
+  }];
+
+  let options = [
+    Option<"indexBitWidth", "index-bitwidth", "unsigned",
+           /*default=*/"64",
+           "Bitwidth of index type.">,
+  ];
+}
+
+
 #endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 5c0d5643c01986..9f3cdd93071ea9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -79,6 +79,10 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
 /// loop bounds and loop steps are canonicalized.
 void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
 
+/// Populate patterns to uplift `scf.while` ops to `scf.for`.
+void populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
+                                      unsigned indexBitwidth);
+
 } // namespace scf
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index fdaeb2fad9afa4..7643bab80a1308 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   ParallelLoopTiling.cpp
   StructuralTypeConversions.cpp
   TileUsingInterface.cpp
+  UpliftWhileToFor.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
new file mode 100644
index 00000000000000..cd16b622504953
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -0,0 +1,250 @@
+//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.WhileOp's into SCF.ForOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFUPLIFTWHILETOFOR
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+static bool checkIndexType(arith::CmpIOp op, unsigned indexBitWidth) {
+  auto type = op.getLhs().getType();
+  if (isa<mlir::IndexType>(type))
+    return true;
+
+  if (type.isSignlessInteger(indexBitWidth))
+    return true;
+
+  return false;
+}
+
+namespace {
+struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
+  UpliftWhileOp(MLIRContext *context, unsigned indexBitWidth_)
+      : OpRewritePattern<scf::WhileOp>(context), indexBitWidth(indexBitWidth_) {
+  }
+
+  LogicalResult matchAndRewrite(scf::WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    Block *beforeBody = loop.getBeforeBody();
+    if (!llvm::hasSingleElement(beforeBody->without_terminator()))
+      return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
+
+    auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
+    if (!cmp)
+      return rewriter.notifyMatchFailure(loop,
+                                         "Loop body must have single cmp op");
+
+    auto beforeTerm = cast<scf::ConditionOp>(beforeBody->getTerminator());
+    if (!llvm::hasSingleElement(cmp->getUses()) ||
+        beforeTerm.getCondition() != cmp.getResult())
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected single condition use: " << *cmp;
+      });
+
+    if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
+      return rewriter.notifyMatchFailure(loop, "Invalid args order");
+
+    using Pred = arith::CmpIPredicate;
+    auto predicate = cmp.getPredicate();
+    if (predicate != Pred::slt && predicate != Pred::sgt)
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
+      });
+
+    if (!checkIndexType(cmp, indexBitWidth))
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Expected index-like type: " << *cmp;
+      });
+
+    BlockArgument iterVar;
+    Value end;
+    DominanceInfo dom;
+    for (bool reverse : {false, true}) {
+      auto expectedPred = reverse ? Pred::sgt : Pred::slt;
+      if (cmp.getPredicate() != expectedPred)
+        continue;
+
+      auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
+      auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
+
+      auto blockArg = dyn_cast<BlockArgument>(arg1);
+      if (!blockArg || blockArg.getOwner() != beforeBody)
+        continue;
+
+      if (!dom.properlyDominates(arg2, loop))
+        continue;
+
+      iterVar = blockArg;
+      end = arg2;
+      break;
+    }
+
+    if (!iterVar)
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Unrecognized cmp form: " << *cmp;
+      });
+
+    if (!llvm::hasNItems(iterVar.getUses(), 2))
+      return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+        diag << "Unrecognized iter var: " << iterVar;
+      });
+
+    Block *afterBody = loop.getAfterBody();
+    auto afterTerm = cast<scf::YieldOp>(afterBody->getTerminator());
+    auto argNumber = iterVar.getArgNumber();
+    auto afterTermIterArg = afterTerm.getResults()[argNumber];
+
+    auto iterVarAfter = afterBody->getArgument(argNumber);
+
+    Value step;
+    for (auto &use : iterVarAfter.getUses()) {
+      auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
+      if (!owner)
+        continue;
+
+      auto other =
+          (iterVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
+      if (!dom.properlyDominates(other, loop))
+        continue;
+
+      if (afterTermIterArg != owner.getResult())
+        continue;
+
+      step = other;
+      break;
+    }
+
+    if (!step)
+      return rewriter.notifyMatchFailure(loop,
+                                         "Didn't find suitable 'add' op");
+
+    auto begin = loop.getInits()[argNumber];
+
+    auto loc = loop.getLoc();
+    auto indexType = rewriter.getIndexType();
+    auto toIndex = [&](Value val) -> Value {
+      if (val.getType() != indexType)
+        return rewriter.create<arith::IndexCastOp>(loc, indexType, val);
+
+      return val;
+    };
+    begin = toIndex(begin);
+    end = toIndex(end);
+    step = toIndex(step);
+
+    llvm::SmallVector<Value> mapping;
+    mapping.reserve(loop.getInits().size());
+    for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
+      if (i == argNumber)
+        continue;
+
+      mapping.emplace_back(init);
+    }
+
+    auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
+    auto newLoop = rewriter.create<scf::ForOp>(loc, begin, end, step, mapping,
+                                               emptyBuilder);
+
+    Block *newBody = newLoop.getBody();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToStart(newBody);
+    Value newIterVar = newBody->getArgument(0);
+    if (newIterVar.getType() != iterVar.getType())
+      newIterVar = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(),
+                                                       newIterVar);
+
+    mapping.clear();
+    auto newArgs = newBody->getArguments();
+    for (auto i : llvm::seq<size_t>(0, newArgs.size())) {
+      if (i < argNumber) {
+        mapping.emplace_back(newArgs[i + 1]);
+      } else if (i == argNumber) {
+        Value arg = newArgs.front();
+        if (arg.getType() != iterVar.getType())
+          arg =
+              rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), arg);
+        mapping.emplace_back(arg);
+      } else {
+        mapping.emplace_back(newArgs[i]);
+      }
+    }
+
+    rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
+                               mapping);
+
+    auto term = cast<scf::YieldOp>(newBody->getTerminator());
+
+    mapping.clear();
+    for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
+      if (i == argNumber)
+        continue;
+
+      mapping.emplace_back(arg);
+    }
+
+    rewriter.setInsertionPoint(term);
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(term, mapping);
+
+    rewriter.setInsertionPointAfter(newLoop);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
+    Value len = rewriter.create<arith::SubIOp>(loc, end, begin);
+    len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
+    len = rewriter.create<arith::DivSIOp>(loc, len, step);
+    len = rewriter.create<arith::SubIOp>(loc, len, one);
+    Value res = rewriter.create<arith::MulIOp>(loc, len, step);
+    res = rewriter.create<arith::AddIOp>(loc, begin, res);
+    if (res.getType() != iterVar.getType())
+      res = rewriter.create<arith::IndexCastOp>(loc, iterVar.getType(), res);
+
+    mapping.clear();
+    llvm::append_range(mapping, newLoop.getResults());
+    mapping.insert(mapping.begin() + argNumber, res);
+    rewriter.replaceOp(loop, mapping);
+    return success();
+  }
+
+private:
+  unsigned indexBitWidth = 0;
+};
+
+struct SCFUpliftWhileToFor final
+    : impl::SCFUpliftWhileToForBase<SCFUpliftWhileToFor> {
+  using SCFUpliftWhileToForBase::SCFUpliftWhileToForBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    RewritePatternSet patterns(ctx);
+    mlir::scf::populateUpliftWhileToForPatterns(patterns, this->indexBitWidth);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns,
+                                                 unsigned indexBitwidth) {
+  patterns.add<UpliftWhileOp>(patterns.getContext(), indexBitwidth);
+}
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
new file mode 100644
index 00000000000000..52a5c0f3cd6347
--- /dev/null
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -0,0 +1,162 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-uplift-while-to-for{index-bitwidth=64}))' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg3, %arg2 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi sgt, %arg1, %arg3 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg3, %arg2 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg2, %arg3 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[STEP]], %[[I]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     return %[[R7]] : index
+
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
+  %c1 = arith.constant 1 : i32
+  %c2 = arith.constant 2.0 : f32
+  %0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (i32, index, f32) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg4, %arg3, %arg5 : i32, index, f32
+  } do {
+  ^bb0(%arg4: i32, %arg3: index, %arg5: f32):
+    %1 = "test.test1"(%arg4) : (i32) -> i32
+    %added = arith.addi %arg3, %arg2 : index
+    %2 = "test.test2"(%arg5) : (f32) -> f32
+    scf.yield %1, %added, %2 : i32, index, f32
+  }
+  return %0#0, %0#2 : i32, f32
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
+//   CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : i32
+//   CHECK-DAG:     %[[C2:.*]] = arith.constant 2.000000e+00 : f32
+//       CHECK:     %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
+//  CHECK-SAME:     iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
+//       CHECK:     %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
+//       CHECK:     %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
+//       CHECK:     scf.yield %[[T1]], %[[T2]] : i32, f32
+//       CHECK:     return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
+  %0 = scf.while (%arg3 = %arg0) : (i64) -> (i64) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : i64
+    scf.condition(%1) %arg3 : i64
+  } do {
+  ^bb0(%arg3: i64):
+    "test.test1"(%arg3) : (i64) -> ()
+    %added = arith.addi %arg3, %arg2 : i64
+    "test.test2"(%added) : (i64) -> ()
+    scf.yield %added : i64
+  }
+  return %0 : i64
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGINI:.*]]: i64, %[[ENDI:.*]]: i64, %[[STEPI:.*]]: i64) -> i64
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     %[[BEGIN:.*]] = arith.index_cast %[[BEGINI]] : i64 to index
+//       CHECK:     %[[END:.*]] = arith.index_cast %[[ENDI]] : i64 to index
+//       CHECK:     %[[STEP:.*]] = arith.index_cast %[[STEPI]] : i64 to index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     %[[II:.*]] = arith.index_cast %[[I]] : index to i64
+//       CHECK:     "test.test1"(%[[II]]) : (i64) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[II]], %[[STEPI]] : i64
+//       CHECK:     "test.test2"(%[[INC]]) : (i64) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.subi %[[R4]], %[[C1]] : index
+//       CHECK:     %[[R6:.*]] = arith.muli %[[R5]], %[[STEP]] : index
+//       CHECK:     %[[R7:.*]] = arith.addi %[[BEGIN]], %[[R6]] : index
+//       CHECK:     %[[RES:.*]] = arith.index_cast %[[R7]] : index to i64
+//       CHECK:     return %[[RES]] : i64
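
The `R1`..`R7` chain checked in the tests above is plain integer arithmetic, so it can be sketched outside MLIR. The following is a minimal model, not the pass itself: `lastIterValue` and `lastIterValueByLoop` are hypothetical helper names, and the sketch assumes `step > 0` and `begin < end` (the loop runs at least once).

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the arith ops emitted after the new scf.for:
//   begin + ((end - begin + (step - 1)) / step - 1) * step
// Assumption: step > 0 and begin < end, so the loop executes at least once.
int64_t lastIterValue(int64_t begin, int64_t end, int64_t step) {
  assert(step > 0 && begin < end);
  int64_t stepDec = step - 1;   // R1
  int64_t len = end - begin;    // R2
  len = len + stepDec;          // R3
  len = len / step;             // R4: trip count, rounded up
  len = len - 1;                // R5
  return begin + len * step;    // R6, R7
}

// Reference: run the equivalent counted loop and remember the last
// induction value that the body observed.
int64_t lastIterValueByLoop(int64_t begin, int64_t end, int64_t step) {
  assert(step > 0 && begin < end);
  int64_t last = begin;
  for (int64_t i = begin; i < end; i += step)
    last = i;
  return last;
}
```

For `begin = 0`, `end = 10`, `step = 3` both helpers return 9: the expression yields the induction value seen by the last executed iteration, not that value plus `step`.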

@Hardcode84
Contributor Author

Hardcode84 commented Dec 20, 2023

@kiranchandramohan this is part of the uplifting pipeline I mentioned in #75314 (comment)

Block *afterBody = loop.getAfterBody();
auto afterTerm = cast<scf::YieldOp>(afterBody->getTerminator());
auto argNumber = iterVar.getArgNumber();
auto afterTermIterArg = afterTerm.getResults()[argNumber];
Member

Do you need some kind of additional check here to make sure that the iter_args are forwarded from the "before" block to the "after" block in the same order?

Contributor Author

I think it's already enforced by the if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) check.
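
As a toy illustration of what that comparison enforces (not MLIR code): if SSA values are modelled as integer ids, the check amounts to an element-wise, order-sensitive equality between the `before` block arguments and the operands forwarded by `scf.condition`. `ValueId` and `forwardsArgsInOrder` are hypothetical names used only for this sketch.

```cpp
#include <vector>

// Toy stand-in for an SSA value: identity is an integer id.
// (Hypothetical; this is not an MLIR type.)
using ValueId = int;

// True only when the condition forwards exactly the block arguments,
// with the same count and in the same order.
bool forwardsArgsInOrder(const std::vector<ValueId> &blockArgs,
                         const std::vector<ValueId> &conditionArgs) {
  return blockArgs == conditionArgs;
}
```

A permuted or partial forwarding list fails the comparison, so a loop that reorders its iter_args between the two blocks would not match.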

@joker-eph
Collaborator

Wouldn't this just be a good canonicalization for scf.while?

@Hardcode84
Contributor Author

Wouldn't this just be a good canonicalization for scf.while?

I'm not sure making this a canonicalization is a good idea. I can imagine a lowering pipeline that looks like scf.for -> scf.while -> cf, and such a canonicalization would actively fight against it.

@joker-eph
Collaborator

I'm not sure making this a canonicalization is a good idea. I can imagine a lowering pipeline that looks like scf.for -> scf.while -> cf, and such a canonicalization would actively fight against it.

I don't quite see the issue though: that just means that such a lowering cannot run canonicalization in the middle. That seems expected to me, in particular if someone designs a lowering with "in-dialect" intermediate steps (like scf.for -> scf.while).

@Hardcode84
Contributor Author

I'm not against making it part of canonicalization, and it will work for our case, but I think there is still a big chance that it can break downstream workflows.

So, if we really want to make it part of canonicalization, I suggest splitting it into two parts:

  • Introduce it as a separate pass first (this PR)
  • Promote it to a canonicalization in a separate PR, with a forum PSA

That way, if we need to revert it, we can revert only the canonicalizer part.

Hardcode84 added a commit to Hardcode84/llvm-project that referenced this pull request Dec 21, 2023
If `before` block args are directly forwarded to `scf.condition` make sure they are passed in the same order.
This is needed for `scf.while` uplifting llvm#76108
Hardcode84 added a commit to Hardcode84/llvm-project that referenced this pull request Dec 25, 2023
Move non-side-effecting ops from `before` region if all their args are defined outside the loop.
This is cleanup needed for `scf.while` -> `scf.for` uplifting llvm#76108 as it expects `before` block consisting of single `cmp` op.
@Hardcode84
Contributor Author

ping

@joker-eph
Collaborator

It's not quite clear to me why we need a PSA for a new canonicalization?
Also please make this a test pass: it does not seem to have any future purpose if we agree on a canonicalization.

@Hardcode84
Contributor Author

As I said, I still believe there is a high chance of it breaking downstream workflows, so a PSA won't hurt. I will update this PR to make it a test pass.

@Hardcode84
Contributor Author

Renamed to test pass

@Hardcode84
Contributor Author

rebased

Hardcode84 added a commit to Hardcode84/llvm-project that referenced this pull request Apr 2, 2024
If `before` block args are directly forwarded to `scf.condition` make sure they are passed in the same order.
This is needed for `scf.while` uplifting llvm#76108
scf::ConditionOp beforeTerm = loop.getConditionOp();
if (!cmp->hasOneUse() && beforeTerm.getCondition() == cmp.getResult())
return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
diag << "Expected single condiditon use: " << *cmp;
Collaborator

Suggested change
- diag << "Expected single condiditon use: " << *cmp;
+ diag << "Expected single condition use: " << *cmp;

Contributor Author

fixed

using namespace mlir;

namespace {
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
Collaborator

It seems to me that we'd better decouple the transformation from the pattern.

That is: expose it as a C++ API, FailureOr<scf::ForOp> upliftWhileToForLoop(scf::WhileOp loop);, declared in a public header, and have the client code decide on its integration point (this could connect to the Transform dialect, for example).
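
The shape of that suggestion can be sketched with toy types, outside MLIR. Everything here is an assumption made for illustration: `WhileLoop` and `ForLoop` stand in for `scf::WhileOp` and `scf::ForOp`, `std::optional` stands in for `FailureOr`, and `UpliftWhileOpPattern` is a hypothetical wrapper showing one possible integration point.

```cpp
#include <optional>

// Toy loop descriptions; hypothetical stand-ins for scf::WhileOp / scf::ForOp.
struct WhileLoop {
  bool beforeIsSingleCmp; // `before` holds only the comparison
  bool afterHasAddi;      // `after` increments the iter var with an add
  long begin, end, step;
};
struct ForLoop {
  long begin, end, step;
};

// Standalone transformation entry point. Returns std::nullopt on match
// failure (std::optional plays the role of FailureOr in this sketch).
std::optional<ForLoop> upliftWhileToForLoop(const WhileLoop &loop) {
  if (!loop.beforeIsSingleCmp || !loop.afterHasAddi)
    return std::nullopt;
  return ForLoop{loop.begin, loop.end, loop.step};
}

// Thin pattern-style wrapper: it only forwards to the function above, so
// other clients can call the function directly without going through it.
struct UpliftWhileOpPattern {
  bool matchAndRewrite(const WhileLoop &loop, ForLoop &out) const {
    std::optional<ForLoop> res = upliftWhileToForLoop(loop);
    if (!res)
      return false;
    out = *res;
    return true;
  }
};
```

Keeping the matching and rewriting logic in the free function means the pattern, a pass, or any other driver can reuse it without duplicating the checks.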

Contributor Author

@Hardcode84 Hardcode84 Apr 2, 2024

done

Do we still want to make while-to-for uplifting part of while op canonicalization later?

@@ -154,4 +154,18 @@ def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
}];
}

def TestSCFUpliftWhileToFor : Pass<"test-scf-uplift-while-to-for"> {
Collaborator

Test passes should be in the test folder I believe.

Contributor Author

done

Hardcode84 added a commit that referenced this pull request Apr 2, 2024
…76195)

If `before` block args are directly forwarded to `scf.condition` make
sure they are passed in the same order.
This is needed for `scf.while` uplifting
#76108
@stellaraccident
Contributor

stellaraccident commented Apr 2, 2024

Given that this is a piece of the uplift pipeline, I think it would be better to see the whole thing than to eagerly install this as a canonicalizer in one step. Such loop transformation optimizations tend to grow into special things, and personally, I have come to dislike optimization pipelines that rely on action-at-a-distance interactions with specific canonicalization patterns for proper functioning. It just makes them very hard to reason about.

I'm +1 on being conservative about making the specific patterns actual canonicalizations until the whole pipeline is in place (unless obvious/universal, which I'm not sure this one is).

@Hardcode84
Contributor Author

Also, it will need one more pattern (PR TBD), which probably doesn't qualify as a canonicalization: duplicate ops from the before block and move them into the loop body and after the loop, e.g.:

scf.while(...) {
before:
  ...
  some_op()
  scf.condition ..
after:
  ...
}

to

scf.while(...) {
before:
  ...
  scf.condition ..
after:
  some_op()
  ...
}
some_op()
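
The equivalence of the two shapes can be checked with an ordinary C++ model. This is a sketch under stated assumptions, not the planned pattern: `someOp` is a hypothetical side-effect-free op, the loop condition is `i < end` with `step > 0`, and `Trace` just records what the `after` body consumed and what the loop finally returns.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical side-effect-free op that originally lives in `before`.
int64_t someOp(int64_t i) { return i * i + 1; }

struct Trace {
  std::vector<int64_t> bodyInputs; // values consumed by the `after` body
  int64_t result = 0;              // value returned by the loop
};

// Original shape: some_op runs in `before`, ahead of the condition.
Trace runOriginal(int64_t begin, int64_t end, int64_t step) {
  Trace t;
  int64_t i = begin;
  while (true) {
    int64_t v = someOp(i);     // before block
    t.result = v;              // forwarded by scf.condition
    if (!(i < end))
      break;
    t.bodyInputs.push_back(v); // after block uses v
    i += step;
  }
  return t;
}

// Rewritten shape: `before` keeps only the comparison; some_op is
// duplicated into `after` and once more after the loop.
Trace runRewritten(int64_t begin, int64_t end, int64_t step) {
  Trace t;
  int64_t i = begin;
  while (i < end) {
    t.bodyInputs.push_back(someOp(i)); // copy inside `after`
    i += step;
  }
  t.result = someOp(i);                // copy after the loop
  return t;
}
```

Both functions produce the same body inputs and the same result, including the zero-trip case, which is why the op has to be duplicated after the loop rather than simply moved into the body.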

@Hardcode84
Contributor Author

ping

Add uplifting from `scf.while` to `scf.for`.

This uplifting expects a very specific ops pattern:
* `before` body consisting of a single `arith.cmp` op
* `after` body containing `arith.addi`
* Iter var must be of type `index` or integer of specified width

We also have a set of patterns to clean up `scf.while` loops to get them close to the desired form; they will be added in separate PRs.

This is part of upstreaming `numba-mlir` scf uplifting pipeline:
`cf -> scf.while -> scf.for -> scf.parallel`

Original code: https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/PromoteToParallel.cpp
@Hardcode84 Hardcode84 merged commit b153c05 into llvm:main Apr 15, 2024
@Hardcode84 Hardcode84 deleted the scf-uplift-while branch April 15, 2024 19:17
aniplcc pushed a commit to aniplcc/llvm-project that referenced this pull request Apr 15, 2024
Add uplifting from `scf.while` to `scf.for`.

This uplifting expects a very specific ops pattern:
* `before` block consisting of single `arith.cmp` op
* `after` block containing `arith.addi`

We also have a set of patterns to cleanup `scf.while` loops to get them
close to the desired form, they will be added in separate PRs.

This is part of upstreaming `numba-mlir` scf uplifting pipeline: `cf ->
scf.while -> scf.for -> scf.parallel`

Original code:
https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/PromoteToParallel.cpp