Skip to content

Introduce new Unroll And Jam loop transform for SCF/Affine loops #94142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 21, 2024

Conversation

AviadCo
Copy link
Contributor

@AviadCo AviadCo commented Jun 2, 2024

Unroll And Jam was supported in affine dialect long time ago using pass. This commit exposes the pattern using transform and in addition adds partial support for SCF loops.

@llvmbot
Copy link
Member

llvmbot commented Jun 2, 2024

@llvm/pr-subscribers-mlir-scf

Author: Aviad Cohen (AviadCo)

Changes

Unroll And Jam was supported in affine dialect long time ago using pass. This commit exposes the pattern using transform and in addition adds partial support for SCF loops.


Patch is 37.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94142.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (+35-1)
  • (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (+8)
  • (added) mlir/include/mlir/IR/LoopUtils.h (+56)
  • (modified) mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (+2-29)
  • (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+22)
  • (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+209-7)
  • (modified) mlir/test/Dialect/SCF/transform-ops-invalid.mlir (+99)
  • (modified) mlir/test/Dialect/SCF/transform-ops.mlir (+215)
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5eefe2664d0a1..29651bed296e5 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -243,7 +243,41 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
     This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
     in the return. If all the operations referred to by the `target` operand
     unroll properly, the transform succeeds. Otherwise the transform produces a
-    silencebale failure.
+    silenceable failure.
+
+    Does not return handles as the operation may result in the loop being
+    removed after a full unrolling.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       ConfinedAttr<I64Attr, [IntPositive]>:$factor);
+
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+def LoopUnrollAndJamOp : Op<Transform_Dialect, "loop.unroll_and_jam",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Unrolls and jam the given loop with the given unroll factor";
+  let description = [{
+    Unrolls & jams each loop associated with the given handle to have up to the given
+    number of loop body copies per iteration. If the unroll factor is larger
+    than the loop trip count, the latter is used as the unroll factor instead.
+
+    #### Return modes
+
+    This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
+    in the return. If all the operations referred to by the `target` operand
+    unroll properly, the transform succeeds. Otherwise the transform produces a
+    silenceable failure.
 
     Does not return handles as the operation may result in the loop being
     removed after a full unrolling.
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bc09cc7f7fa5e..ad3113c960d37 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,6 +120,14 @@ LogicalResult loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
+/// Unrolls and jam this for operation by the specified unroll factor. Returns
+/// failure if the loop cannot be unrolled either due to restrictions or due to
+/// invalid unroll factors. In case of unroll factor of 1, the function bails
+/// out without doing anything (returns success). Currently, only constant trip
+/// count that are divided by the unroll factor is supported. Currently, for
+/// operations with results are not supported.
+LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor);
+
 /// Tile a nest of standard for loops rooted at `rootForOp` by finding such
 /// parametric tile sizes that the outer loops have a fixed number of iterations
 /// as defined in `sizes`.
diff --git a/mlir/include/mlir/IR/LoopUtils.h b/mlir/include/mlir/IR/LoopUtils.h
new file mode 100644
index 0000000000000..410eec223e637
--- /dev/null
+++ b/mlir/include/mlir/IR/LoopUtils.h
@@ -0,0 +1,56 @@
+//===- LoopUtils.h -  LoopUtils Support ---------------------*- C++
+//-*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions for the action framework. This framework
+// allows for external entities to control certain actions taken by the compiler
+// by registering handler functions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_LOOP_UTILS_H
+#define MLIR_IR_LOOP_UTILS_H
+
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+
+// Gathers all maximal sub-blocks of operations that do not themselves
+// include a `OpTy` (an operation could have a descendant `OpTy` though
+// in its tree). Ignore the block terminators.
+template <typename OpTy>
+struct JamBlockGatherer {
+  // Store iterators to the first and last op of each sub-block found.
+  llvm::SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+
+  // This is a linear time walk.
+  void walk(Operation *op) {
+    for (auto &region : op->getRegions())
+      for (auto &block : region)
+        walk(block);
+  }
+
+  void walk(Block &block) {
+    assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
+           "expected block to have a terminator");
+    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
+      auto subBlockStart = it;
+      while (it != e && !isa<OpTy>(&*it))
+        ++it;
+      if (it != subBlockStart)
+        subBlocks.emplace_back(subBlockStart, std::prev(it));
+      // Process all for ops that appear next.
+      while (it != e && isa<OpTy>(&*it))
+        walk(&*it++);
+    }
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_IR_LOOP_UTILS_H
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 268050a30e002..6d35f06faa45b 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/LoopUtils.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/MapVector.h"
@@ -1102,34 +1103,6 @@ static bool areInnerBoundsInvariant(AffineForOp forOp) {
   return !walkResult.wasInterrupted();
 }
 
-// Gathers all maximal sub-blocks of operations that do not themselves
-// include a for op (a operation could have a descendant for op though
-// in its tree).  Ignore the block terminators.
-struct JamBlockGatherer {
-  // Store iterators to the first and last op of each sub-block found.
-  std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
-
-  // This is a linear time walk.
-  void walk(Operation *op) {
-    for (auto &region : op->getRegions())
-      for (auto &block : region)
-        walk(block);
-  }
-
-  void walk(Block &block) {
-    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
-      auto subBlockStart = it;
-      while (it != e && !isa<AffineForOp>(&*it))
-        ++it;
-      if (it != subBlockStart)
-        subBlocks.emplace_back(subBlockStart, std::prev(it));
-      // Process all for ops that appear next.
-      while (it != e && isa<AffineForOp>(&*it))
-        walk(&*it++);
-    }
-  }
-};
-
 /// Unrolls and jams this loop by the specified factor.
 LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
                                                   uint64_t unrollJamFactor) {
@@ -1159,7 +1132,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
     return failure();
 
   // Gather all sub-blocks to jam upon the loop being unrolled.
-  JamBlockGatherer jbg;
+  JamBlockGatherer<AffineForOp> jbg;
   jbg.walk(forOp);
   auto &subBlocks = jbg.subBlocks;
 
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..df5701a40dc1b 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -304,6 +304,28 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// LoopUnrollAndJamOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *op,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  LogicalResult result(failure());
+  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
+    result = loopUnrollJamByFactor(scfFor, getFactor());
+  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
+    result = loopUnrollJamByFactor(affineFor, getFactor());
+
+  if (failed(result)) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "failed to unroll and jam";
+    return diag;
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopCoalesceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6658cca03eba7..07ae3507c3ded 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/LoopUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/MathExtras.h"
@@ -26,9 +27,15 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include <cstdint>
 
 using namespace mlir;
 
+#define DEBUG_TYPE "scf-utils"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace {
 // This structure is to pass and return sets of loop parameters without
 // confusing the order.
@@ -296,6 +303,23 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
   return builder.create<arith::DivUIOp>(loc, sum, divisor);
 }
 
+static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
+  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
+  if (lbCstOp && ubCstOp && stepCstOp) {
+    // Constant loop bounds computation.
+    int64_t lbCst = lbCstOp.value();
+    int64_t ubCst = ubCstOp.value();
+    int64_t stepCst = stepCstOp.value();
+    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
+           "expected positive loop bounds and step");
+    return mlir::ceilDiv(ubCst - lbCst, stepCst);
+  }
+
+  return {};
+}
+
 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -375,22 +399,21 @@ LogicalResult mlir::loopUnrollByFactor(
   std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
   std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
   std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
-  if (lbCstOp && ubCstOp && stepCstOp) {
+  auto constTripCount = getConstantTripCount(forOp);
+  if (constTripCount) {
     // Constant loop bounds computation.
     int64_t lbCst = lbCstOp.value();
     int64_t ubCst = ubCstOp.value();
     int64_t stepCst = stepCstOp.value();
-    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
-           "expected positive loop bounds and step");
-    int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
-
     if (unrollFactor == 1) {
-      if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
+      if (*constTripCount == 1 &&
+          failed(forOp.promoteIfSingleIteration(rewriter)))
         return failure();
       return success();
     }
 
-    int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
+    int64_t tripCountEvenMultiple =
+        *constTripCount - (*constTripCount % unrollFactor);
     int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
     int64_t stepUnrolledCst = stepCst * unrollFactor;
 
@@ -473,6 +496,185 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
+/// Check if bounds of all inner loops are defined outside of `forOp`
+/// and return false if not.
+static bool areInnerBoundsInvariant(scf::ForOp forOp) {
+  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
+    if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
+      return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted();
+}
+
+/// Unrolls and jams this loop by the specified factor.
+LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
+                                          uint64_t unrollJamFactor) {
+  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
+
+  if (unrollJamFactor == 1)
+    return success();
+
+  // If any control operand of any inner loop of `forOp` is defined within
+  // `forOp`, no unroll jam.
+  if (!areInnerBoundsInvariant(forOp)) {
+    LDBG("failed to unroll and jam: inner bounds are not invariant");
+    return failure();
+  }
+
+  // Currently, for operations with results are not supported.
+  if (forOp->getNumResults() > 0) {
+    LDBG("failed to unroll and jam: unsupported loop with results");
+    return failure();
+  }
+
+  // Currently, only constant trip count that divided by the unroll factor is
+  // supported.
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount) {
+    // If the trip count is dynamic, do not unroll & jam.
+    LDBG("failed to unroll and jam: trip count could not be determined");
+    return failure();
+  }
+  if (unrollJamFactor > *tripCount) {
+    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
+         "count");
+    unrollJamFactor = *tripCount;
+  } else if (*tripCount % unrollJamFactor != 0) {
+    LDBG("failed to unroll and jam: unsupported trip count that is not a "
+         "multiple of unroll jam factor");
+    return failure();
+  }
+
+  // Nothing in the loop body other than the terminator.
+  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+    return success();
+
+  // Gather all sub-blocks to jam upon the loop being unrolled.
+  JamBlockGatherer<scf::ForOp> jbg;
+  jbg.walk(forOp);
+  auto &subBlocks = jbg.subBlocks;
+
+  // Collect inner loops.
+  SmallVector<scf::ForOp, 4> innerLoops;
+  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+
+  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
+  // iteration. There are (`unrollJamFactor` - 1) iterations.
+  SmallVector<IRMapping, 4> operandMaps(unrollJamFactor - 1);
+
+  // For any loop with iter_args, replace it with a new loop that has
+  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
+  // operands.
+  SmallVector<scf::ForOp, 4> newInnerLoops;
+  IRRewriter rewriter(forOp.getContext());
+  for (scf::ForOp oldForOp : innerLoops) {
+    SmallVector<Value> dupIterOperands, dupYieldOperands;
+    ValueRange oldIterOperands = oldForOp.getInits();
+    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
+    ValueRange oldYieldOperands =
+        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
+    // Get additional iterOperands, iterArgs, and yield operands. We will
+    // fix iterOperands and yield operands after cloning of sub-blocks.
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
+      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
+    }
+    // Create a new loop with additional iterOperands, iter_args and yield
+    // operands. This new loop will take the loop body of the original loop.
+    bool forOpReplaced = oldForOp == forOp;
+    scf::ForOp newForOp =
+        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
+            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
+            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+              return dupYieldOperands;
+            }));
+    newInnerLoops.push_back(newForOp);
+    // `forOp` has been replaced with a new loop.
+    if (forOpReplaced)
+      forOp = newForOp;
+    // Update `operandMaps` for `newForOp` iterArgs and results.
+    ValueRange newIterArgs = newForOp.getRegionIterArgs();
+    unsigned oldNumIterArgs = oldIterArgs.size();
+    ValueRange newResults = newForOp.getResults();
+    unsigned oldNumResults = newResults.size() / unrollJamFactor;
+    assert(oldNumIterArgs == oldNumResults &&
+           "oldNumIterArgs must be the same as oldNumResults");
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
+        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
+        // to those in the `i`th new set.
+        operandMaps[i - 1].map(newIterArgs[j],
+                               newIterArgs[i * oldNumIterArgs + j]);
+        operandMaps[i - 1].map(newResults[j],
+                               newResults[i * oldNumResults + j]);
+      }
+    }
+  }
+
+  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
+  rewriter.setInsertionPoint(forOp);
+  int64_t step = forOp.getConstantStep()->getSExtValue();
+  auto newStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), forOp.getStep(),
+      rewriter.createOrFold<arith::ConstantOp>(
+          forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
+  forOp.setStep(newStep);
+  auto forOpIV = forOp.getInductionVar();
+
+  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
+  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+    for (auto &subBlock : subBlocks) {
+      // Builder to insert unroll-jammed bodies. Insert right at the end of
+      // sub-block.
+      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
+
+      // If the induction variable is used, create a remapping to the value for
+      // this unrolled instance.
+      if (!forOpIV.use_empty()) {
+        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
+        auto ivTag = builder.createOrFold<arith::ConstantOp>(
+            forOp.getLoc(), builder.getIndexAttr(step * i));
+        auto ivUnroll =
+            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
+        operandMaps[i - 1].map(forOpIV, ivUnroll);
+      }
+      // Clone the sub-block being unroll-jammed.
+      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
+        builder.clone(*it, operandMaps[i - 1]);
+    }
+    // Fix iterOperands and yield op operands of newly created loops.
+    for (auto newForOp : newInnerLoops) {
+      unsigned oldNumIterOperands =
+          newForOp.getNumRegionIterArgs() / unrollJamFactor;
+      unsigned numControlOperands = newForOp.getNumControlOperands();
+      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
+      assert(oldNumIterOperands == oldNumYieldOperands &&
+             "oldNumIterOperands must be the same as oldNumYieldOperands");
+      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+        // The `i`th duplication of an old iterOperand or yield op operand
+        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
+        // if such mapped value exists.
+        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+                            operandMaps[i - 1].lookupOrDefault(
+                   ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 2, 2024

@llvm/pr-subscribers-mlir

Author: Aviad Cohen (AviadCo)

Changes

Unroll And Jam was supported in affine dialect long time ago using pass. This commit exposes the pattern using transform and in addition adds partial support for SCF loops.


Patch is 37.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94142.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (+35-1)
  • (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (+8)
  • (added) mlir/include/mlir/IR/LoopUtils.h (+56)
  • (modified) mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (+2-29)
  • (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+22)
  • (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+209-7)
  • (modified) mlir/test/Dialect/SCF/transform-ops-invalid.mlir (+99)
  • (modified) mlir/test/Dialect/SCF/transform-ops.mlir (+215)
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5eefe2664d0a1..29651bed296e5 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -243,7 +243,41 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
     This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
     in the return. If all the operations referred to by the `target` operand
     unroll properly, the transform succeeds. Otherwise the transform produces a
-    silencebale failure.
+    silenceable failure.
+
+    Does not return handles as the operation may result in the loop being
+    removed after a full unrolling.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       ConfinedAttr<I64Attr, [IntPositive]>:$factor);
+
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+def LoopUnrollAndJamOp : Op<Transform_Dialect, "loop.unroll_and_jam",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Unrolls and jam the given loop with the given unroll factor";
+  let description = [{
+    Unrolls & jams each loop associated with the given handle to have up to the given
+    number of loop body copies per iteration. If the unroll factor is larger
+    than the loop trip count, the latter is used as the unroll factor instead.
+
+    #### Return modes
+
+    This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
+    in the return. If all the operations referred to by the `target` operand
+    unroll properly, the transform succeeds. Otherwise the transform produces a
+    silenceable failure.
 
     Does not return handles as the operation may result in the loop being
     removed after a full unrolling.
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bc09cc7f7fa5e..ad3113c960d37 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,6 +120,14 @@ LogicalResult loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
+/// Unrolls and jam this for operation by the specified unroll factor. Returns
+/// failure if the loop cannot be unrolled either due to restrictions or due to
+/// invalid unroll factors. In case of unroll factor of 1, the function bails
+/// out without doing anything (returns success). Currently, only constant trip
+/// count that are divided by the unroll factor is supported. Currently, for
+/// operations with results are not supported.
+LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor);
+
 /// Tile a nest of standard for loops rooted at `rootForOp` by finding such
 /// parametric tile sizes that the outer loops have a fixed number of iterations
 /// as defined in `sizes`.
diff --git a/mlir/include/mlir/IR/LoopUtils.h b/mlir/include/mlir/IR/LoopUtils.h
new file mode 100644
index 0000000000000..410eec223e637
--- /dev/null
+++ b/mlir/include/mlir/IR/LoopUtils.h
@@ -0,0 +1,56 @@
+//===- LoopUtils.h -  LoopUtils Support ---------------------*- C++
+//-*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions for the action framework. This framework
+// allows for external entities to control certain actions taken by the compiler
+// by registering handler functions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_LOOP_UTILS_H
+#define MLIR_IR_LOOP_UTILS_H
+
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+
+// Gathers all maximal sub-blocks of operations that do not themselves
+// include a `OpTy` (an operation could have a descendant `OpTy` though
+// in its tree). Ignore the block terminators.
+template <typename OpTy>
+struct JamBlockGatherer {
+  // Store iterators to the first and last op of each sub-block found.
+  llvm::SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+
+  // This is a linear time walk.
+  void walk(Operation *op) {
+    for (auto &region : op->getRegions())
+      for (auto &block : region)
+        walk(block);
+  }
+
+  void walk(Block &block) {
+    assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
+           "expected block to have a terminator");
+    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
+      auto subBlockStart = it;
+      while (it != e && !isa<OpTy>(&*it))
+        ++it;
+      if (it != subBlockStart)
+        subBlocks.emplace_back(subBlockStart, std::prev(it));
+      // Process all for ops that appear next.
+      while (it != e && isa<OpTy>(&*it))
+        walk(&*it++);
+    }
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_IR_LOOP_UTILS_H
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 268050a30e002..6d35f06faa45b 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/LoopUtils.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/MapVector.h"
@@ -1102,34 +1103,6 @@ static bool areInnerBoundsInvariant(AffineForOp forOp) {
   return !walkResult.wasInterrupted();
 }
 
-// Gathers all maximal sub-blocks of operations that do not themselves
-// include a for op (a operation could have a descendant for op though
-// in its tree).  Ignore the block terminators.
-struct JamBlockGatherer {
-  // Store iterators to the first and last op of each sub-block found.
-  std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
-
-  // This is a linear time walk.
-  void walk(Operation *op) {
-    for (auto &region : op->getRegions())
-      for (auto &block : region)
-        walk(block);
-  }
-
-  void walk(Block &block) {
-    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
-      auto subBlockStart = it;
-      while (it != e && !isa<AffineForOp>(&*it))
-        ++it;
-      if (it != subBlockStart)
-        subBlocks.emplace_back(subBlockStart, std::prev(it));
-      // Process all for ops that appear next.
-      while (it != e && isa<AffineForOp>(&*it))
-        walk(&*it++);
-    }
-  }
-};
-
 /// Unrolls and jams this loop by the specified factor.
 LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
                                                   uint64_t unrollJamFactor) {
@@ -1159,7 +1132,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
     return failure();
 
   // Gather all sub-blocks to jam upon the loop being unrolled.
-  JamBlockGatherer jbg;
+  JamBlockGatherer<AffineForOp> jbg;
   jbg.walk(forOp);
   auto &subBlocks = jbg.subBlocks;
 
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..df5701a40dc1b 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -304,6 +304,28 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// LoopUnrollAndJamOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *op,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  LogicalResult result(failure());
+  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
+    result = loopUnrollJamByFactor(scfFor, getFactor());
+  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
+    result = loopUnrollJamByFactor(affineFor, getFactor());
+
+  if (failed(result)) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "failed to unroll and jam";
+    return diag;
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopCoalesceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6658cca03eba7..07ae3507c3ded 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/LoopUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/MathExtras.h"
@@ -26,9 +27,15 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include <cstdint>
 
 using namespace mlir;
 
+#define DEBUG_TYPE "scf-utils"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace {
 // This structure is to pass and return sets of loop parameters without
 // confusing the order.
@@ -296,6 +303,23 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
   return builder.create<arith::DivUIOp>(loc, sum, divisor);
 }
 
+static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
+  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
+  if (lbCstOp && ubCstOp && stepCstOp) {
+    // Constant loop bounds computation.
+    int64_t lbCst = lbCstOp.value();
+    int64_t ubCst = ubCstOp.value();
+    int64_t stepCst = stepCstOp.value();
+    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
+           "expected positive loop bounds and step");
+    return mlir::ceilDiv(ubCst - lbCst, stepCst);
+  }
+
+  return {};
+}
+
 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -375,22 +399,21 @@ LogicalResult mlir::loopUnrollByFactor(
   std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
   std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
   std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
-  if (lbCstOp && ubCstOp && stepCstOp) {
+  auto constTripCount = getConstantTripCount(forOp);
+  if (constTripCount) {
     // Constant loop bounds computation.
     int64_t lbCst = lbCstOp.value();
     int64_t ubCst = ubCstOp.value();
     int64_t stepCst = stepCstOp.value();
-    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
-           "expected positive loop bounds and step");
-    int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
-
     if (unrollFactor == 1) {
-      if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
+      if (*constTripCount == 1 &&
+          failed(forOp.promoteIfSingleIteration(rewriter)))
         return failure();
       return success();
     }
 
-    int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
+    int64_t tripCountEvenMultiple =
+        *constTripCount - (*constTripCount % unrollFactor);
     int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
     int64_t stepUnrolledCst = stepCst * unrollFactor;
 
@@ -473,6 +496,185 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
+/// Check if bounds of all inner loops are defined outside of `forOp`
+/// and return false if not.
+static bool areInnerBoundsInvariant(scf::ForOp forOp) {
+  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
+    if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
+      return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted();
+}
+
+/// Unrolls and jams this loop by the specified factor.
+LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
+                                          uint64_t unrollJamFactor) {
+  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
+
+  if (unrollJamFactor == 1)
+    return success();
+
+  // If any control operand of any inner loop of `forOp` is defined within
+  // `forOp`, no unroll jam.
+  if (!areInnerBoundsInvariant(forOp)) {
+    LDBG("failed to unroll and jam: inner bounds are not invariant");
+    return failure();
+  }
+
+  // Currently, for operations with results are not supported.
+  if (forOp->getNumResults() > 0) {
+    LDBG("failed to unroll and jam: unsupported loop with results");
+    return failure();
+  }
+
+  // Currently, only constant trip count that divided by the unroll factor is
+  // supported.
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount) {
+    // If the trip count is dynamic, do not unroll & jam.
+    LDBG("failed to unroll and jam: trip count could not be determined");
+    return failure();
+  }
+  if (unrollJamFactor > *tripCount) {
+    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
+         "count");
+    unrollJamFactor = *tripCount;
+  } else if (*tripCount % unrollJamFactor != 0) {
+    LDBG("failed to unroll and jam: unsupported trip count that is not a "
+         "multiple of unroll jam factor");
+    return failure();
+  }
+
+  // Nothing in the loop body other than the terminator.
+  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+    return success();
+
+  // Gather all sub-blocks to jam upon the loop being unrolled.
+  JamBlockGatherer<scf::ForOp> jbg;
+  jbg.walk(forOp);
+  auto &subBlocks = jbg.subBlocks;
+
+  // Collect inner loops.
+  SmallVector<scf::ForOp, 4> innerLoops;
+  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+
+  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
+  // iteration. There are (`unrollJamFactor` - 1) iterations.
+  SmallVector<IRMapping, 4> operandMaps(unrollJamFactor - 1);
+
+  // For any loop with iter_args, replace it with a new loop that has
+  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
+  // operands.
+  SmallVector<scf::ForOp, 4> newInnerLoops;
+  IRRewriter rewriter(forOp.getContext());
+  for (scf::ForOp oldForOp : innerLoops) {
+    SmallVector<Value> dupIterOperands, dupYieldOperands;
+    ValueRange oldIterOperands = oldForOp.getInits();
+    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
+    ValueRange oldYieldOperands =
+        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
+    // Get additional iterOperands, iterArgs, and yield operands. We will
+    // fix iterOperands and yield operands after cloning of sub-blocks.
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
+      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
+    }
+    // Create a new loop with additional iterOperands, iter_args and yield
+    // operands. This new loop will take the loop body of the original loop.
+    bool forOpReplaced = oldForOp == forOp;
+    scf::ForOp newForOp =
+        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
+            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
+            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+              return dupYieldOperands;
+            }));
+    newInnerLoops.push_back(newForOp);
+    // `forOp` has been replaced with a new loop.
+    if (forOpReplaced)
+      forOp = newForOp;
+    // Update `operandMaps` for `newForOp` iterArgs and results.
+    ValueRange newIterArgs = newForOp.getRegionIterArgs();
+    unsigned oldNumIterArgs = oldIterArgs.size();
+    ValueRange newResults = newForOp.getResults();
+    unsigned oldNumResults = newResults.size() / unrollJamFactor;
+    assert(oldNumIterArgs == oldNumResults &&
+           "oldNumIterArgs must be the same as oldNumResults");
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
+        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
+        // to those in the `i`th new set.
+        operandMaps[i - 1].map(newIterArgs[j],
+                               newIterArgs[i * oldNumIterArgs + j]);
+        operandMaps[i - 1].map(newResults[j],
+                               newResults[i * oldNumResults + j]);
+      }
+    }
+  }
+
+  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
+  rewriter.setInsertionPoint(forOp);
+  int64_t step = forOp.getConstantStep()->getSExtValue();
+  auto newStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), forOp.getStep(),
+      rewriter.createOrFold<arith::ConstantOp>(
+          forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
+  forOp.setStep(newStep);
+  auto forOpIV = forOp.getInductionVar();
+
+  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
+  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+    for (auto &subBlock : subBlocks) {
+      // Builder to insert unroll-jammed bodies. Insert right at the end of
+      // sub-block.
+      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
+
+      // If the induction variable is used, create a remapping to the value for
+      // this unrolled instance.
+      if (!forOpIV.use_empty()) {
+        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
+        auto ivTag = builder.createOrFold<arith::ConstantOp>(
+            forOp.getLoc(), builder.getIndexAttr(step * i));
+        auto ivUnroll =
+            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
+        operandMaps[i - 1].map(forOpIV, ivUnroll);
+      }
+      // Clone the sub-block being unroll-jammed.
+      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
+        builder.clone(*it, operandMaps[i - 1]);
+    }
+    // Fix iterOperands and yield op operands of newly created loops.
+    for (auto newForOp : newInnerLoops) {
+      unsigned oldNumIterOperands =
+          newForOp.getNumRegionIterArgs() / unrollJamFactor;
+      unsigned numControlOperands = newForOp.getNumControlOperands();
+      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
+      assert(oldNumIterOperands == oldNumYieldOperands &&
+             "oldNumIterOperands must be the same as oldNumYieldOperands");
+      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+        // The `i`th duplication of an old iterOperand or yield op operand
+        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
+        // if such mapped value exists.
+        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+                            operandMaps[i - 1].lookupOrDefault(
+                   ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 2, 2024

@llvm/pr-subscribers-mlir-core

Author: Aviad Cohen (AviadCo)

Changes

Unroll And Jam was supported in affine dialect long time ago using pass. This commit exposes the pattern using transform and in addition adds partial support for SCF loops.


Patch is 37.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94142.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (+35-1)
  • (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (+8)
  • (added) mlir/include/mlir/IR/LoopUtils.h (+56)
  • (modified) mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (+2-29)
  • (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+22)
  • (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+209-7)
  • (modified) mlir/test/Dialect/SCF/transform-ops-invalid.mlir (+99)
  • (modified) mlir/test/Dialect/SCF/transform-ops.mlir (+215)
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5eefe2664d0a1..29651bed296e5 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -243,7 +243,41 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
     This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
     in the return. If all the operations referred to by the `target` operand
     unroll properly, the transform succeeds. Otherwise the transform produces a
-    silencebale failure.
+    silenceable failure.
+
+    Does not return handles as the operation may result in the loop being
+    removed after a full unrolling.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       ConfinedAttr<I64Attr, [IntPositive]>:$factor);
+
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+def LoopUnrollAndJamOp : Op<Transform_Dialect, "loop.unroll_and_jam",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Unrolls and jam the given loop with the given unroll factor";
+  let description = [{
+    Unrolls & jams each loop associated with the given handle to have up to the given
+    number of loop body copies per iteration. If the unroll factor is larger
+    than the loop trip count, the latter is used as the unroll factor instead.
+
+    #### Return modes
+
+    This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
+    in the return. If all the operations referred to by the `target` operand
+    unroll properly, the transform succeeds. Otherwise the transform produces a
+    silenceable failure.
 
     Does not return handles as the operation may result in the loop being
     removed after a full unrolling.
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bc09cc7f7fa5e..ad3113c960d37 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,6 +120,14 @@ LogicalResult loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
+/// Unrolls and jam this for operation by the specified unroll factor. Returns
+/// failure if the loop cannot be unrolled either due to restrictions or due to
+/// invalid unroll factors. In case of unroll factor of 1, the function bails
+/// out without doing anything (returns success). Currently, only constant trip
+/// count that are divided by the unroll factor is supported. Currently, for
+/// operations with results are not supported.
+LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor);
+
 /// Tile a nest of standard for loops rooted at `rootForOp` by finding such
 /// parametric tile sizes that the outer loops have a fixed number of iterations
 /// as defined in `sizes`.
diff --git a/mlir/include/mlir/IR/LoopUtils.h b/mlir/include/mlir/IR/LoopUtils.h
new file mode 100644
index 0000000000000..410eec223e637
--- /dev/null
+++ b/mlir/include/mlir/IR/LoopUtils.h
@@ -0,0 +1,56 @@
+//===- LoopUtils.h -  LoopUtils Support ---------------------*- C++
+//-*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions for the action framework. This framework
+// allows for external entities to control certain actions taken by the compiler
+// by registering handler functions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_LOOP_UTILS_H
+#define MLIR_IR_LOOP_UTILS_H
+
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+
+// Gathers all maximal sub-blocks of operations that do not themselves
+// include a `OpTy` (an operation could have a descendant `OpTy` though
+// in its tree). Ignore the block terminators.
+template <typename OpTy>
+struct JamBlockGatherer {
+  // Store iterators to the first and last op of each sub-block found.
+  llvm::SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+
+  // This is a linear time walk.
+  void walk(Operation *op) {
+    for (auto &region : op->getRegions())
+      for (auto &block : region)
+        walk(block);
+  }
+
+  void walk(Block &block) {
+    assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
+           "expected block to have a terminator");
+    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
+      auto subBlockStart = it;
+      while (it != e && !isa<OpTy>(&*it))
+        ++it;
+      if (it != subBlockStart)
+        subBlocks.emplace_back(subBlockStart, std::prev(it));
+      // Process all for ops that appear next.
+      while (it != e && isa<OpTy>(&*it))
+        walk(&*it++);
+    }
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_IR_LOOP_UTILS_H
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 268050a30e002..6d35f06faa45b 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/LoopUtils.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/MapVector.h"
@@ -1102,34 +1103,6 @@ static bool areInnerBoundsInvariant(AffineForOp forOp) {
   return !walkResult.wasInterrupted();
 }
 
-// Gathers all maximal sub-blocks of operations that do not themselves
-// include a for op (a operation could have a descendant for op though
-// in its tree).  Ignore the block terminators.
-struct JamBlockGatherer {
-  // Store iterators to the first and last op of each sub-block found.
-  std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
-
-  // This is a linear time walk.
-  void walk(Operation *op) {
-    for (auto &region : op->getRegions())
-      for (auto &block : region)
-        walk(block);
-  }
-
-  void walk(Block &block) {
-    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
-      auto subBlockStart = it;
-      while (it != e && !isa<AffineForOp>(&*it))
-        ++it;
-      if (it != subBlockStart)
-        subBlocks.emplace_back(subBlockStart, std::prev(it));
-      // Process all for ops that appear next.
-      while (it != e && isa<AffineForOp>(&*it))
-        walk(&*it++);
-    }
-  }
-};
-
 /// Unrolls and jams this loop by the specified factor.
 LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
                                                   uint64_t unrollJamFactor) {
@@ -1159,7 +1132,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
     return failure();
 
   // Gather all sub-blocks to jam upon the loop being unrolled.
-  JamBlockGatherer jbg;
+  JamBlockGatherer<AffineForOp> jbg;
   jbg.walk(forOp);
   auto &subBlocks = jbg.subBlocks;
 
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..df5701a40dc1b 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -304,6 +304,28 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// LoopUnrollAndJamOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *op,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  LogicalResult result(failure());
+  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
+    result = loopUnrollJamByFactor(scfFor, getFactor());
+  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
+    result = loopUnrollJamByFactor(affineFor, getFactor());
+
+  if (failed(result)) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "failed to unroll and jam";
+    return diag;
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopCoalesceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6658cca03eba7..07ae3507c3ded 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/LoopUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/MathExtras.h"
@@ -26,9 +27,15 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include <cstdint>
 
 using namespace mlir;
 
+#define DEBUG_TYPE "scf-utils"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace {
 // This structure is to pass and return sets of loop parameters without
 // confusing the order.
@@ -296,6 +303,23 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
   return builder.create<arith::DivUIOp>(loc, sum, divisor);
 }
 
+static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
+  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
+  if (lbCstOp && ubCstOp && stepCstOp) {
+    // Constant loop bounds computation.
+    int64_t lbCst = lbCstOp.value();
+    int64_t ubCst = ubCstOp.value();
+    int64_t stepCst = stepCstOp.value();
+    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
+           "expected positive loop bounds and step");
+    return mlir::ceilDiv(ubCst - lbCst, stepCst);
+  }
+
+  return {};
+}
+
 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -375,22 +399,21 @@ LogicalResult mlir::loopUnrollByFactor(
   std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
   std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
   std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
-  if (lbCstOp && ubCstOp && stepCstOp) {
+  auto constTripCount = getConstantTripCount(forOp);
+  if (constTripCount) {
     // Constant loop bounds computation.
     int64_t lbCst = lbCstOp.value();
     int64_t ubCst = ubCstOp.value();
     int64_t stepCst = stepCstOp.value();
-    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
-           "expected positive loop bounds and step");
-    int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
-
     if (unrollFactor == 1) {
-      if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
+      if (*constTripCount == 1 &&
+          failed(forOp.promoteIfSingleIteration(rewriter)))
         return failure();
       return success();
     }
 
-    int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
+    int64_t tripCountEvenMultiple =
+        *constTripCount - (*constTripCount % unrollFactor);
     int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
     int64_t stepUnrolledCst = stepCst * unrollFactor;
 
@@ -473,6 +496,185 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
+/// Check if bounds of all inner loops are defined outside of `forOp`
+/// and return false if not.
+static bool areInnerBoundsInvariant(scf::ForOp forOp) {
+  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
+    if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
+      return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted();
+}
+
+/// Unrolls and jams this loop by the specified factor.
+LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
+                                          uint64_t unrollJamFactor) {
+  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
+
+  if (unrollJamFactor == 1)
+    return success();
+
+  // If any control operand of any inner loop of `forOp` is defined within
+  // `forOp`, no unroll jam.
+  if (!areInnerBoundsInvariant(forOp)) {
+    LDBG("failed to unroll and jam: inner bounds are not invariant");
+    return failure();
+  }
+
+  // Currently, for operations with results are not supported.
+  if (forOp->getNumResults() > 0) {
+    LDBG("failed to unroll and jam: unsupported loop with results");
+    return failure();
+  }
+
+  // Currently, only constant trip count that divided by the unroll factor is
+  // supported.
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount) {
+    // If the trip count is dynamic, do not unroll & jam.
+    LDBG("failed to unroll and jam: trip count could not be determined");
+    return failure();
+  }
+  if (unrollJamFactor > *tripCount) {
+    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
+         "count");
+    unrollJamFactor = *tripCount;
+  } else if (*tripCount % unrollJamFactor != 0) {
+    LDBG("failed to unroll and jam: unsupported trip count that is not a "
+         "multiple of unroll jam factor");
+    return failure();
+  }
+
+  // Nothing in the loop body other than the terminator.
+  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+    return success();
+
+  // Gather all sub-blocks to jam upon the loop being unrolled.
+  JamBlockGatherer<scf::ForOp> jbg;
+  jbg.walk(forOp);
+  auto &subBlocks = jbg.subBlocks;
+
+  // Collect inner loops.
+  SmallVector<scf::ForOp, 4> innerLoops;
+  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+
+  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
+  // iteration. There are (`unrollJamFactor` - 1) iterations.
+  SmallVector<IRMapping, 4> operandMaps(unrollJamFactor - 1);
+
+  // For any loop with iter_args, replace it with a new loop that has
+  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
+  // operands.
+  SmallVector<scf::ForOp, 4> newInnerLoops;
+  IRRewriter rewriter(forOp.getContext());
+  for (scf::ForOp oldForOp : innerLoops) {
+    SmallVector<Value> dupIterOperands, dupYieldOperands;
+    ValueRange oldIterOperands = oldForOp.getInits();
+    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
+    ValueRange oldYieldOperands =
+        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
+    // Get additional iterOperands, iterArgs, and yield operands. We will
+    // fix iterOperands and yield operands after cloning of sub-blocks.
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
+      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
+    }
+    // Create a new loop with additional iterOperands, iter_args and yield
+    // operands. This new loop will take the loop body of the original loop.
+    bool forOpReplaced = oldForOp == forOp;
+    scf::ForOp newForOp =
+        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
+            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
+            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+              return dupYieldOperands;
+            }));
+    newInnerLoops.push_back(newForOp);
+    // `forOp` has been replaced with a new loop.
+    if (forOpReplaced)
+      forOp = newForOp;
+    // Update `operandMaps` for `newForOp` iterArgs and results.
+    ValueRange newIterArgs = newForOp.getRegionIterArgs();
+    unsigned oldNumIterArgs = oldIterArgs.size();
+    ValueRange newResults = newForOp.getResults();
+    unsigned oldNumResults = newResults.size() / unrollJamFactor;
+    assert(oldNumIterArgs == oldNumResults &&
+           "oldNumIterArgs must be the same as oldNumResults");
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
+        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
+        // to those in the `i`th new set.
+        operandMaps[i - 1].map(newIterArgs[j],
+                               newIterArgs[i * oldNumIterArgs + j]);
+        operandMaps[i - 1].map(newResults[j],
+                               newResults[i * oldNumResults + j]);
+      }
+    }
+  }
+
+  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
+  rewriter.setInsertionPoint(forOp);
+  int64_t step = forOp.getConstantStep()->getSExtValue();
+  auto newStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), forOp.getStep(),
+      rewriter.createOrFold<arith::ConstantOp>(
+          forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
+  forOp.setStep(newStep);
+  auto forOpIV = forOp.getInductionVar();
+
+  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
+  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+    for (auto &subBlock : subBlocks) {
+      // Builder to insert unroll-jammed bodies. Insert right at the end of
+      // sub-block.
+      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
+
+      // If the induction variable is used, create a remapping to the value for
+      // this unrolled instance.
+      if (!forOpIV.use_empty()) {
+        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
+        auto ivTag = builder.createOrFold<arith::ConstantOp>(
+            forOp.getLoc(), builder.getIndexAttr(step * i));
+        auto ivUnroll =
+            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
+        operandMaps[i - 1].map(forOpIV, ivUnroll);
+      }
+      // Clone the sub-block being unroll-jammed.
+      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
+        builder.clone(*it, operandMaps[i - 1]);
+    }
+    // Fix iterOperands and yield op operands of newly created loops.
+    for (auto newForOp : newInnerLoops) {
+      unsigned oldNumIterOperands =
+          newForOp.getNumRegionIterArgs() / unrollJamFactor;
+      unsigned numControlOperands = newForOp.getNumControlOperands();
+      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
+      assert(oldNumIterOperands == oldNumYieldOperands &&
+             "oldNumIterOperands must be the same as oldNumYieldOperands");
+      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+        // The `i`th duplication of an old iterOperand or yield op operand
+        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
+        // if such mapped value exists.
+        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+                            operandMaps[i - 1].lookupOrDefault(
+                   ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 2, 2024

@llvm/pr-subscribers-mlir-affine

Author: Aviad Cohen (AviadCo)

Changes

Unroll And Jam was supported in affine dialect long time ago using pass. This commit exposes the pattern using transform and in addition adds partial support for SCF loops.


Patch is 37.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94142.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (+35-1)
  • (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (+8)
  • (added) mlir/include/mlir/IR/LoopUtils.h (+56)
  • (modified) mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (+2-29)
  • (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+22)
  • (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+209-7)
  • (modified) mlir/test/Dialect/SCF/transform-ops-invalid.mlir (+99)
  • (modified) mlir/test/Dialect/SCF/transform-ops.mlir (+215)
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5eefe2664d0a1..29651bed296e5 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -243,7 +243,41 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
     This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
     in the return. If all the operations referred to by the `target` operand
     unroll properly, the transform succeeds. Otherwise the transform produces a
-    silencebale failure.
+    silenceable failure.
+
+    Does not return handles as the operation may result in the loop being
+    removed after a full unrolling.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       ConfinedAttr<I64Attr, [IntPositive]>:$factor);
+
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+def LoopUnrollAndJamOp : Op<Transform_Dialect, "loop.unroll_and_jam",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Unrolls and jam the given loop with the given unroll factor";
+  let description = [{
+    Unrolls & jams each loop associated with the given handle to have up to the given
+    number of loop body copies per iteration. If the unroll factor is larger
+    than the loop trip count, the latter is used as the unroll factor instead.
+
+    #### Return modes
+
+    This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
+    in the return. If all the operations referred to by the `target` operand
+    unroll properly, the transform succeeds. Otherwise the transform produces a
+    silenceable failure.
 
     Does not return handles as the operation may result in the loop being
     removed after a full unrolling.
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bc09cc7f7fa5e..ad3113c960d37 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,6 +120,14 @@ LogicalResult loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
+/// Unrolls and jam this for operation by the specified unroll factor. Returns
+/// failure if the loop cannot be unrolled either due to restrictions or due to
+/// invalid unroll factors. In case of unroll factor of 1, the function bails
+/// out without doing anything (returns success). Currently, only constant trip
+/// count that are divided by the unroll factor is supported. Currently, for
+/// operations with results are not supported.
+LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor);
+
 /// Tile a nest of standard for loops rooted at `rootForOp` by finding such
 /// parametric tile sizes that the outer loops have a fixed number of iterations
 /// as defined in `sizes`.
diff --git a/mlir/include/mlir/IR/LoopUtils.h b/mlir/include/mlir/IR/LoopUtils.h
new file mode 100644
index 0000000000000..410eec223e637
--- /dev/null
+++ b/mlir/include/mlir/IR/LoopUtils.h
@@ -0,0 +1,56 @@
+//===- LoopUtils.h -  LoopUtils Support ---------------------*- C++
+//-*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions for the action framework. This framework
+// allows for external entities to control certain actions taken by the compiler
+// by registering handler functions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_LOOP_UTILS_H
+#define MLIR_IR_LOOP_UTILS_H
+
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+
+// Gathers all maximal sub-blocks of operations that do not themselves
+// include a `OpTy` (an operation could have a descendant `OpTy` though
+// in its tree). Ignore the block terminators.
+template <typename OpTy>
+struct JamBlockGatherer {
+  // Store iterators to the first and last op of each sub-block found.
+  llvm::SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+
+  // This is a linear time walk.
+  void walk(Operation *op) {
+    for (auto &region : op->getRegions())
+      for (auto &block : region)
+        walk(block);
+  }
+
+  void walk(Block &block) {
+    assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
+           "expected block to have a terminator");
+    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
+      auto subBlockStart = it;
+      while (it != e && !isa<OpTy>(&*it))
+        ++it;
+      if (it != subBlockStart)
+        subBlocks.emplace_back(subBlockStart, std::prev(it));
+      // Process all for ops that appear next.
+      while (it != e && isa<OpTy>(&*it))
+        walk(&*it++);
+    }
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_IR_LOOP_UTILS_H
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 268050a30e002..6d35f06faa45b 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/LoopUtils.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/MapVector.h"
@@ -1102,34 +1103,6 @@ static bool areInnerBoundsInvariant(AffineForOp forOp) {
   return !walkResult.wasInterrupted();
 }
 
-// Gathers all maximal sub-blocks of operations that do not themselves
-// include a for op (a operation could have a descendant for op though
-// in its tree).  Ignore the block terminators.
-struct JamBlockGatherer {
-  // Store iterators to the first and last op of each sub-block found.
-  std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
-
-  // This is a linear time walk.
-  void walk(Operation *op) {
-    for (auto &region : op->getRegions())
-      for (auto &block : region)
-        walk(block);
-  }
-
-  void walk(Block &block) {
-    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
-      auto subBlockStart = it;
-      while (it != e && !isa<AffineForOp>(&*it))
-        ++it;
-      if (it != subBlockStart)
-        subBlocks.emplace_back(subBlockStart, std::prev(it));
-      // Process all for ops that appear next.
-      while (it != e && isa<AffineForOp>(&*it))
-        walk(&*it++);
-    }
-  }
-};
-
 /// Unrolls and jams this loop by the specified factor.
 LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
                                                   uint64_t unrollJamFactor) {
@@ -1159,7 +1132,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
     return failure();
 
   // Gather all sub-blocks to jam upon the loop being unrolled.
-  JamBlockGatherer jbg;
+  JamBlockGatherer<AffineForOp> jbg;
   jbg.walk(forOp);
   auto &subBlocks = jbg.subBlocks;
 
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..df5701a40dc1b 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -304,6 +304,28 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// LoopUnrollAndJamOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *op,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  LogicalResult result(failure());
+  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
+    result = loopUnrollJamByFactor(scfFor, getFactor());
+  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
+    result = loopUnrollJamByFactor(affineFor, getFactor());
+
+  if (failed(result)) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "failed to unroll and jam";
+    return diag;
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopCoalesceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6658cca03eba7..07ae3507c3ded 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/LoopUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/MathExtras.h"
@@ -26,9 +27,15 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include <cstdint>
 
 using namespace mlir;
 
+#define DEBUG_TYPE "scf-utils"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace {
 // This structure is to pass and return sets of loop parameters without
 // confusing the order.
@@ -296,6 +303,23 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
   return builder.create<arith::DivUIOp>(loc, sum, divisor);
 }
 
+static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
+  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
+  if (lbCstOp && ubCstOp && stepCstOp) {
+    // Constant loop bounds computation.
+    int64_t lbCst = lbCstOp.value();
+    int64_t ubCst = ubCstOp.value();
+    int64_t stepCst = stepCstOp.value();
+    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
+           "expected positive loop bounds and step");
+    return mlir::ceilDiv(ubCst - lbCst, stepCst);
+  }
+
+  return {};
+}
+
 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -375,22 +399,21 @@ LogicalResult mlir::loopUnrollByFactor(
   std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
   std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
   std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
-  if (lbCstOp && ubCstOp && stepCstOp) {
+  auto constTripCount = getConstantTripCount(forOp);
+  if (constTripCount) {
     // Constant loop bounds computation.
     int64_t lbCst = lbCstOp.value();
     int64_t ubCst = ubCstOp.value();
     int64_t stepCst = stepCstOp.value();
-    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
-           "expected positive loop bounds and step");
-    int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
-
     if (unrollFactor == 1) {
-      if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
+      if (*constTripCount == 1 &&
+          failed(forOp.promoteIfSingleIteration(rewriter)))
         return failure();
       return success();
     }
 
-    int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
+    int64_t tripCountEvenMultiple =
+        *constTripCount - (*constTripCount % unrollFactor);
     int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
     int64_t stepUnrolledCst = stepCst * unrollFactor;
 
@@ -473,6 +496,185 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
+/// Check if bounds of all inner loops are defined outside of `forOp`
+/// and return false if not.
+static bool areInnerBoundsInvariant(scf::ForOp forOp) {
+  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
+    if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
+      return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted();
+}
+
+/// Unrolls and jams this loop by the specified factor.
+LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
+                                          uint64_t unrollJamFactor) {
+  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
+
+  if (unrollJamFactor == 1)
+    return success();
+
+  // If any control operand of any inner loop of `forOp` is defined within
+  // `forOp`, no unroll jam.
+  if (!areInnerBoundsInvariant(forOp)) {
+    LDBG("failed to unroll and jam: inner bounds are not invariant");
+    return failure();
+  }
+
+  // Currently, for operations with results are not supported.
+  if (forOp->getNumResults() > 0) {
+    LDBG("failed to unroll and jam: unsupported loop with results");
+    return failure();
+  }
+
+  // Currently, only constant trip count that divided by the unroll factor is
+  // supported.
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount) {
+    // If the trip count is dynamic, do not unroll & jam.
+    LDBG("failed to unroll and jam: trip count could not be determined");
+    return failure();
+  }
+  if (unrollJamFactor > *tripCount) {
+    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
+         "count");
+    unrollJamFactor = *tripCount;
+  } else if (*tripCount % unrollJamFactor != 0) {
+    LDBG("failed to unroll and jam: unsupported trip count that is not a "
+         "multiple of unroll jam factor");
+    return failure();
+  }
+
+  // Nothing in the loop body other than the terminator.
+  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+    return success();
+
+  // Gather all sub-blocks to jam upon the loop being unrolled.
+  JamBlockGatherer<scf::ForOp> jbg;
+  jbg.walk(forOp);
+  auto &subBlocks = jbg.subBlocks;
+
+  // Collect inner loops.
+  SmallVector<scf::ForOp, 4> innerLoops;
+  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+
+  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
+  // iteration. There are (`unrollJamFactor` - 1) iterations.
+  SmallVector<IRMapping, 4> operandMaps(unrollJamFactor - 1);
+
+  // For any loop with iter_args, replace it with a new loop that has
+  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
+  // operands.
+  SmallVector<scf::ForOp, 4> newInnerLoops;
+  IRRewriter rewriter(forOp.getContext());
+  for (scf::ForOp oldForOp : innerLoops) {
+    SmallVector<Value> dupIterOperands, dupYieldOperands;
+    ValueRange oldIterOperands = oldForOp.getInits();
+    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
+    ValueRange oldYieldOperands =
+        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
+    // Get additional iterOperands, iterArgs, and yield operands. We will
+    // fix iterOperands and yield operands after cloning of sub-blocks.
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
+      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
+    }
+    // Create a new loop with additional iterOperands, iter_args and yield
+    // operands. This new loop will take the loop body of the original loop.
+    bool forOpReplaced = oldForOp == forOp;
+    scf::ForOp newForOp =
+        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
+            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
+            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+              return dupYieldOperands;
+            }));
+    newInnerLoops.push_back(newForOp);
+    // `forOp` has been replaced with a new loop.
+    if (forOpReplaced)
+      forOp = newForOp;
+    // Update `operandMaps` for `newForOp` iterArgs and results.
+    ValueRange newIterArgs = newForOp.getRegionIterArgs();
+    unsigned oldNumIterArgs = oldIterArgs.size();
+    ValueRange newResults = newForOp.getResults();
+    unsigned oldNumResults = newResults.size() / unrollJamFactor;
+    assert(oldNumIterArgs == oldNumResults &&
+           "oldNumIterArgs must be the same as oldNumResults");
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
+        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
+        // to those in the `i`th new set.
+        operandMaps[i - 1].map(newIterArgs[j],
+                               newIterArgs[i * oldNumIterArgs + j]);
+        operandMaps[i - 1].map(newResults[j],
+                               newResults[i * oldNumResults + j]);
+      }
+    }
+  }
+
+  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
+  rewriter.setInsertionPoint(forOp);
+  int64_t step = forOp.getConstantStep()->getSExtValue();
+  auto newStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), forOp.getStep(),
+      rewriter.createOrFold<arith::ConstantOp>(
+          forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
+  forOp.setStep(newStep);
+  auto forOpIV = forOp.getInductionVar();
+
+  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
+  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+    for (auto &subBlock : subBlocks) {
+      // Builder to insert unroll-jammed bodies. Insert right at the end of
+      // sub-block.
+      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
+
+      // If the induction variable is used, create a remapping to the value for
+      // this unrolled instance.
+      if (!forOpIV.use_empty()) {
+        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
+        auto ivTag = builder.createOrFold<arith::ConstantOp>(
+            forOp.getLoc(), builder.getIndexAttr(step * i));
+        auto ivUnroll =
+            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
+        operandMaps[i - 1].map(forOpIV, ivUnroll);
+      }
+      // Clone the sub-block being unroll-jammed.
+      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
+        builder.clone(*it, operandMaps[i - 1]);
+    }
+    // Fix iterOperands and yield op operands of newly created loops.
+    for (auto newForOp : newInnerLoops) {
+      unsigned oldNumIterOperands =
+          newForOp.getNumRegionIterArgs() / unrollJamFactor;
+      unsigned numControlOperands = newForOp.getNumControlOperands();
+      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
+      assert(oldNumIterOperands == oldNumYieldOperands &&
+             "oldNumIterOperands must be the same as oldNumYieldOperands");
+      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+        // The `i`th duplication of an old iterOperand or yield op operand
+        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
+        // if such mapped value exists.
+        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+                            operandMaps[i - 1].lookupOrDefault(
+                   ...
[truncated]

@AviadCo AviadCo requested a review from MaheshRavishankar June 2, 2024 06:17
@AviadCo AviadCo force-pushed the transform/unroll-and-jam branch from f52e7b2 to c596182 Compare June 2, 2024 07:41
@AviadCo AviadCo requested a review from matthias-springer June 4, 2024 08:34
Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! Please address the comments.

Unroll And Jam was supported in affine dialect long time ago using pass.
This commit exposes the pattern using transform and in addition adds
partial support for SCF loops.
@AviadCo AviadCo force-pushed the transform/unroll-and-jam branch from c596182 to e02c2ab Compare June 17, 2024 03:27
@AviadCo
Copy link
Contributor Author

AviadCo commented Jun 17, 2024

@ftynse @wyzero thanks for the review! answered your comments.

Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
@AviadCo AviadCo merged commit cb8bd6f into llvm:main Jun 21, 2024
7 checks passed
@AviadCo AviadCo deleted the transform/unroll-and-jam branch June 21, 2024 12:48
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
…m#94142)

Unroll And Jam was supported in affine dialect long time ago using pass.
This commit exposes the pattern using transform and in addition adds
partial support for SCF loops.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants