Skip to content

[mlir][affine] re-land implement promoteIfSingleIteration for AffineForOp #72805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Nov 19, 2023

I had to revert #72547 because I didn't notice a dep on func::FuncOp in promoteIfSingleIteration:

if (forOp.hasConstantLowerBound()) {
  OpBuilder topBuilder(forOp->getParentOfType<func::FuncOp>().getBody());
  auto constOp = topBuilder.create<arith::ConstantIndexOp>(
      forOp.getLoc(), forOp.getConstantLowerBound());

I.e., hoist the arith.constant to the nearest func. The alternative I implemented here

if (forOp.hasConstantLowerBound()) {
  Operation *parentOp = forOp.getOperation();
  while (isa<AffineForOp>(parentOp->getParentOp()))
    parentOp = parentOp->getParentOp();
  Block *parentBlock = parentOp->getBlock();
  OpBuilder topBuilder(parentBlock, parentBlock->begin());

I.e., hoist to the beginning of the first op that isn't a loop. Tests pass but obviously it's not equivalent.

In general, this should just be handled by cse?

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2023

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

I had to revert #72547 because I didn't notice a dep on func::FuncOp in promoteIfSingleIteration:

if (forOp.hasConstantLowerBound()) {
  OpBuilder topBuilder(forOp-&gt;getParentOfType&lt;func::FuncOp&gt;().getBody());
  auto constOp = topBuilder.create&lt;arith::ConstantIndexOp&gt;(
      forOp.getLoc(), forOp.getConstantLowerBound());

I.e., hoist the arith.constant to the nearest func. The alternative I implemented here

if (forOp.hasConstantLowerBound()) {
  Operation *parentOp = forOp.getOperation();
  while (isa&lt;AffineForOp&gt;(parentOp-&gt;getParentOp()))
    parentOp = parentOp-&gt;getParentOp();
  Block *parentBlock = parentOp-&gt;getBlock();
  OpBuilder topBuilder(parentBlock, parentBlock-&gt;begin());

but just wanted to make sure.


Patch is 29.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72805.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h (+1-14)
  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.h (+25-4)
  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Affine/LoopUtils.h (-4)
  • (modified) mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp (+1-79)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+144-7)
  • (modified) mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (+13-70)
  • (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+2-1)
  • (modified) mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp (+2-1)
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h b/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
index 92f3d5a2c4925b1..c629c3a1c562322 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
+
 #include <optional>
 
 namespace mlir {
@@ -29,20 +30,6 @@ namespace affine {
 class AffineForOp;
 class NestedPattern;
 
-/// Returns the trip count of the loop as an affine map with its corresponding
-/// operands if the latter is expressible as an affine expression, and nullptr
-/// otherwise. This method always succeeds as long as the lower bound is not a
-/// multi-result map. The trip count expression is simplified before returning.
-/// This method only utilizes map composition to construct lower and upper
-/// bounds before computing the trip count expressions
-void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
-                                SmallVectorImpl<Value> *operands);
-
-/// Returns the trip count of the loop if it's a constant, std::nullopt
-/// otherwise. This uses affine expression analysis and is able to determine
-/// constant trip count in non-trivial cases.
-std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
-
 /// Returns the greatest known integral divisor of the trip count. Affine
 /// expression analysis is used (indirectly through getTripCount), and
 /// this method is thus able to determine non-trivial divisors.
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index f070d0488619063..f763cf339159a50 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -117,7 +117,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the source memref.
   AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
   AffineMapAttr getSrcMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getSrcMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getSrcMapAttrStrName()));
   }
 
   /// Returns the source memref affine map indices for this DMA operation.
@@ -156,7 +157,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the destination memref.
   AffineMap getDstMap() { return getDstMapAttr().getValue(); }
   AffineMapAttr getDstMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getDstMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getDstMapAttrStrName()));
   }
 
   /// Returns the destination memref indices for this DMA operation.
@@ -185,7 +187,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the tag memref.
   AffineMap getTagMap() { return getTagMapAttr().getValue(); }
   AffineMapAttr getTagMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getTagMapAttrStrName()));
   }
 
   /// Returns the tag memref indices for this DMA operation.
@@ -307,7 +310,8 @@ class AffineDmaWaitOp
   /// Returns the affine map used to access the tag memref.
   AffineMap getTagMap() { return getTagMapAttr().getValue(); }
   AffineMapAttr getTagMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getTagMapAttrStrName()));
   }
 
   /// Returns the tag memref index for this DMA operation.
@@ -465,6 +469,23 @@ AffineForOp getForInductionVarOwner(Value val);
 /// AffineParallelOp.
 AffineParallelOp getAffineParallelInductionVarOwner(Value val);
 
+/// Helper to replace uses of loop carried values (iter_args) and loop
+/// yield values while promoting single iteration affine.for ops.
+void replaceIterArgsAndYieldResults(AffineForOp forOp);
+
+/// Returns the trip count of the loop as an affine expression if the latter is
+/// expressible as an affine expression, and nullptr otherwise. The trip count
+/// expression is simplified before returning. This method only utilizes map
+/// composition to construct lower and upper bounds before computing the trip
+/// count expressions.
+void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *tripCountMap,
+                                SmallVectorImpl<Value> *tripCountOperands);
+
+/// Returns the trip count of the loop if it's a constant, std::nullopt
+/// otherwise. This uses affine expression analysis and is able to determine
+/// constant trip count in non-trivial cases.
+std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
+
 /// Extracts the induction variables from a list of AffineForOps and places them
 /// in the output argument `ivs`.
 void extractForInductionVars(ArrayRef<AffineForOp> forInsts,
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index f9578cf37d5d768..b4ea6122ed4c0e0 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -121,7 +121,7 @@ def AffineForOp : Affine_Op<"for",
      ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
      ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-      "getSingleUpperBound", "getYieldedValuesMutable",
+      "getSingleUpperBound", "getYieldedValuesMutable", "promoteIfSingleIteration",
       "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {
diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 723a262f24acc51..1e3b3bffea7b838 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -83,10 +83,6 @@ LogicalResult loopUnrollJamByFactor(AffineForOp forOp,
 LogicalResult loopUnrollJamUpToFactor(AffineForOp forOp,
                                       uint64_t unrollJamFactor);
 
-/// Promotes the loop body of a AffineForOp to its containing block if the loop
-/// was known to have a single iteration.
-LogicalResult promoteIfSingleIteration(AffineForOp forOp);
-
 /// Promotes all single iteration AffineForOp's in the Function, i.e., moves
 /// their body into the containing Block.
 void promoteSingleIterationLoops(func::FuncOp f);
diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
index e645afe7cd3e8fa..24f119464b416a7 100644
--- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
@@ -12,7 +12,6 @@
 
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 
-#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
@@ -20,9 +19,9 @@
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Support/MathExtras.h"
 
-#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallString.h"
+
 #include <numeric>
 #include <optional>
 #include <type_traits>
@@ -30,83 +29,6 @@
 using namespace mlir;
 using namespace mlir::affine;
 
-/// Returns the trip count of the loop as an affine expression if the latter is
-/// expressible as an affine expression, and nullptr otherwise. The trip count
-/// expression is simplified before returning. This method only utilizes map
-/// composition to construct lower and upper bounds before computing the trip
-/// count expressions.
-void mlir::affine::getTripCountMapAndOperands(
-    AffineForOp forOp, AffineMap *tripCountMap,
-    SmallVectorImpl<Value> *tripCountOperands) {
-  MLIRContext *context = forOp.getContext();
-  int64_t step = forOp.getStepAsInt();
-  int64_t loopSpan;
-  if (forOp.hasConstantBounds()) {
-    int64_t lb = forOp.getConstantLowerBound();
-    int64_t ub = forOp.getConstantUpperBound();
-    loopSpan = ub - lb;
-    if (loopSpan < 0)
-      loopSpan = 0;
-    *tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
-    tripCountOperands->clear();
-    return;
-  }
-  auto lbMap = forOp.getLowerBoundMap();
-  auto ubMap = forOp.getUpperBoundMap();
-  if (lbMap.getNumResults() != 1) {
-    *tripCountMap = AffineMap();
-    return;
-  }
-
-  // Difference of each upper bound expression from the single lower bound
-  // expression (divided by the step) provides the expressions for the trip
-  // count map.
-  AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());
-
-  SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
-                                         lbMap.getResult(0));
-  auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
-                                   lbSplatExpr, context);
-  AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
-
-  AffineValueMap tripCountValueMap;
-  AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
-  for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
-    tripCountValueMap.setResult(i,
-                                tripCountValueMap.getResult(i).ceilDiv(step));
-
-  *tripCountMap = tripCountValueMap.getAffineMap();
-  tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
-                            tripCountValueMap.getOperands().end());
-}
-
-/// Returns the trip count of the loop if it's a constant, std::nullopt
-/// otherwise. This method uses affine expression analysis (in turn using
-/// getTripCount) and is able to determine constant trip count in non-trivial
-/// cases.
-std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
-  SmallVector<Value, 4> operands;
-  AffineMap map;
-  getTripCountMapAndOperands(forOp, &map, &operands);
-
-  if (!map)
-    return std::nullopt;
-
-  // Take the min if all trip counts are constant.
-  std::optional<uint64_t> tripCount;
-  for (auto resultExpr : map.getResults()) {
-    if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
-      if (tripCount.has_value())
-        tripCount =
-            std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
-      else
-        tripCount = constExpr.getValue();
-    } else
-      return std::nullopt;
-  }
-  return tripCount;
-}
-
 /// Returns the greatest known integral divisor of the trip count. Affine
 /// expression analysis is used (indirectly through getTripCount), and
 /// this method is thus able to determine non-trivial divisors.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 05496e70716a2a1..8716d7a3525b526 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AffineExprVisitor.h"
@@ -23,6 +24,7 @@
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+
 #include <numeric>
 #include <optional>
 
@@ -2440,6 +2442,69 @@ std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
   return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
 }
 
+void mlir::affine::replaceIterArgsAndYieldResults(AffineForOp forOp) {
+  // Replace uses of iter arguments with iter operands (initial values).
+  OperandRange iterOperands = forOp.getInits();
+  MutableArrayRef<BlockArgument> iterArgs = forOp.getRegionIterArgs();
+  for (auto [operand, arg] : llvm::zip(iterOperands, iterArgs))
+    arg.replaceAllUsesWith(operand);
+
+  // Replace uses of loop results with the values yielded by the loop.
+  ResultRange outerResults = forOp.getResults();
+  OperandRange innerResults = forOp.getBody()->getTerminator()->getOperands();
+  for (auto [outer, inner] : llvm::zip(outerResults, innerResults))
+    outer.replaceAllUsesWith(inner);
+}
+
+LogicalResult AffineForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
+  auto forOp = cast<AffineForOp>(getOperation());
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount || *tripCount != 1)
+    return failure();
+
+  // TODO: extend this for arbitrary affine bounds.
+  if (forOp.getLowerBoundMap().getNumResults() != 1)
+    return failure();
+
+  // Replaces all IV uses to its single iteration value.
+  BlockArgument iv = forOp.getInductionVar();
+  if (!iv.use_empty()) {
+    if (forOp.hasConstantLowerBound()) {
+      Operation *parentOp = forOp.getOperation();
+      while (isa<AffineForOp>(parentOp->getParentOp()))
+        parentOp = parentOp->getParentOp();
+      Block *parentBlock = parentOp->getBlock();
+      OpBuilder topBuilder(parentBlock, parentBlock->begin());
+      auto constOp = topBuilder.create<arith::ConstantIndexOp>(
+          forOp.getLoc(), forOp.getConstantLowerBound());
+      iv.replaceAllUsesWith(constOp);
+    } else {
+      OperandRange lbOperands = forOp.getLowerBoundOperands();
+      AffineMap lbMap = forOp.getLowerBoundMap();
+      OpBuilder builder(forOp);
+      if (lbMap == builder.getDimIdentityMap()) {
+        // No need of generating an affine.apply.
+        iv.replaceAllUsesWith(lbOperands[0]);
+      } else {
+        auto affineApplyOp =
+            builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
+        iv.replaceAllUsesWith(affineApplyOp);
+      }
+    }
+  }
+
+  replaceIterArgsAndYieldResults(forOp);
+
+  // Move the loop body operations, except for its terminator, to the loop's
+  // containing block.
+  forOp.getBody()->back().erase();
+  Block *parentBlock = forOp->getBlock();
+  parentBlock->getOperations().splice(Block::iterator(forOp),
+                                      forOp.getBody()->getOperations());
+  forOp.erase();
+  return success();
+}
+
 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
     RewriterBase &rewriter, ValueRange newInitOperands,
     bool replaceInitOperandUsesInLoop,
@@ -2538,6 +2603,79 @@ AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
   return nullptr;
 }
 
+/// Returns the trip count of the loop as an affine expression if the latter is
+/// expressible as an affine expression, and nullptr otherwise. The trip count
+/// expression is simplified before returning. This method only utilizes map
+/// composition to construct lower and upper bounds before computing the trip
+/// count expressions.
+void mlir::affine::getTripCountMapAndOperands(
+    AffineForOp forOp, AffineMap *tripCountMap,
+    SmallVectorImpl<Value> *tripCountOperands) {
+  MLIRContext *context = forOp.getContext();
+  int64_t step = forOp.getStepAsInt();
+  int64_t loopSpan;
+  if (forOp.hasConstantBounds()) {
+    int64_t lb = forOp.getConstantLowerBound();
+    int64_t ub = forOp.getConstantUpperBound();
+    loopSpan = ub - lb;
+    if (loopSpan < 0)
+      loopSpan = 0;
+    *tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
+    tripCountOperands->clear();
+    return;
+  }
+  auto lbMap = forOp.getLowerBoundMap();
+  auto ubMap = forOp.getUpperBoundMap();
+  if (lbMap.getNumResults() != 1) {
+    *tripCountMap = AffineMap();
+    return;
+  }
+
+  // Difference of each upper bound expression from the single lower bound
+  // expression (divided by the step) provides the expressions for the trip
+  // count map.
+  AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());
+
+  SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
+                                         lbMap.getResult(0));
+  auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
+                                   lbSplatExpr, context);
+  AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
+
+  AffineValueMap tripCountValueMap;
+  AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
+  for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
+    tripCountValueMap.setResult(i,
+                                tripCountValueMap.getResult(i).ceilDiv(step));
+
+  *tripCountMap = tripCountValueMap.getAffineMap();
+  tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
+                            tripCountValueMap.getOperands().end());
+}
+
+std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
+  SmallVector<Value, 4> operands;
+  AffineMap map;
+  getTripCountMapAndOperands(forOp, &map, &operands);
+
+  if (!map)
+    return std::nullopt;
+
+  // Take the min if all trip counts are constant.
+  std::optional<uint64_t> tripCount;
+  for (auto resultExpr : map.getResults()) {
+    if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
+      if (tripCount.has_value())
+        tripCount =
+            std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
+      else
+        tripCount = constExpr.getValue();
+    } else
+      return std::nullopt;
+  }
+  return tripCount;
+}
+
 /// Extracts the induction variables from a list of AffineForOps and returns
 /// them.
 void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
@@ -2905,8 +3043,7 @@ static void composeSetAndOperands(IntegerSet &set,
 }
 
 /// Canonicalize an affine if op's conditional (integer set + operands).
-LogicalResult AffineIfOp::fold(FoldAdaptor,
-                               SmallVectorImpl<OpFoldResult> &) {
+LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   auto set = getIntegerSet();
   SmallVector<Value, 4> operands(getOperands());
   composeSetAndOperands(set, operands);
@@ -2997,11 +3134,11 @@ static LogicalResult
 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
                        Operation::operand_range mapOperands,
                        MemRefType memrefType, unsigned numIndexOperands) {
-    AffineMap map = mapAttr.getValue();
-    if (map.getNumResults() != memrefType.getRank())
-      return op->emitOpError("affine map num results must equal memref rank");
-    if (map.getNumInputs() != numIndexOperands)
-      return op->emitOpError("expects as many subscripts as affine map inputs");
+  AffineMap map = mapAttr.getValue();
+  if (map.getNumResults() != memrefType.getRank())
+    return op->emitOpError("affine map num results must equal memref rank");
+  if (map.getNumInputs() != numIndexOperands)
+    return op->emitOpError("expects as many subscripts as affine map inputs");
 
   Region *scope = getAffineScope(op);
   for (auto idx : mapOperands) {
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 331b0f1b2c2b1c6..31b90a60472c1f1 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -219,13 +219,14 @@ void AffineDataCopyGeneration::runOnOperation() {
 
   // Promote any single iteration loops in the copy nests and collect
   // load/stores to simplify.
+  IRRewriter rewriter(f.getContext());
   SmallVector<Operation *, 4> copyOps;
   for (Operation *nest : copyNests)
     // With a post order walk, the erasure of loops does not affect
     // continuation of the walk or the collection of load/store ops.
     nest->walk([&](Operation *op) {
       if (auto forOp = dyn_cast<AffineForOp>(op))
-        (void)promoteIfSingleIteration(forOp);
+        (void)forOp.promoteIfSingleIteration(rewriter);
       else if (isa<AffineLoadOp, AffineStoreOp>(op))
         copyOps.push_back(op);
     });
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 5053b08ee0834cd..d11e77544e24ea5 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -457,6 +457,7 @@ void m...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2023

@llvm/pr-subscribers-mlir-affine

Author: Maksim Levental (makslevental)

Changes

I had to revert #72547 because I didn't notice a dep on func::FuncOp in promoteIfSingleIteration:

if (forOp.hasConstantLowerBound()) {
  OpBuilder topBuilder(forOp-&gt;getParentOfType&lt;func::FuncOp&gt;().getBody());
  auto constOp = topBuilder.create&lt;arith::ConstantIndexOp&gt;(
      forOp.getLoc(), forOp.getConstantLowerBound());

I.e., hoist the arith.constant to the nearest func. The alternative I implemented here

if (forOp.hasConstantLowerBound()) {
  Operation *parentOp = forOp.getOperation();
  while (isa&lt;AffineForOp&gt;(parentOp-&gt;getParentOp()))
    parentOp = parentOp-&gt;getParentOp();
  Block *parentBlock = parentOp-&gt;getBlock();
  OpBuilder topBuilder(parentBlock, parentBlock-&gt;begin());

but just wanted to make sure.


Patch is 29.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72805.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h (+1-14)
  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.h (+25-4)
  • (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Affine/LoopUtils.h (-4)
  • (modified) mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp (+1-79)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+144-7)
  • (modified) mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (+13-70)
  • (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+2-1)
  • (modified) mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp (+2-1)
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h b/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
index 92f3d5a2c4925b1..c629c3a1c562322 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
+
 #include <optional>
 
 namespace mlir {
@@ -29,20 +30,6 @@ namespace affine {
 class AffineForOp;
 class NestedPattern;
 
-/// Returns the trip count of the loop as an affine map with its corresponding
-/// operands if the latter is expressible as an affine expression, and nullptr
-/// otherwise. This method always succeeds as long as the lower bound is not a
-/// multi-result map. The trip count expression is simplified before returning.
-/// This method only utilizes map composition to construct lower and upper
-/// bounds before computing the trip count expressions
-void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
-                                SmallVectorImpl<Value> *operands);
-
-/// Returns the trip count of the loop if it's a constant, std::nullopt
-/// otherwise. This uses affine expression analysis and is able to determine
-/// constant trip count in non-trivial cases.
-std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
-
 /// Returns the greatest known integral divisor of the trip count. Affine
 /// expression analysis is used (indirectly through getTripCount), and
 /// this method is thus able to determine non-trivial divisors.
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index f070d0488619063..f763cf339159a50 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -117,7 +117,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the source memref.
   AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
   AffineMapAttr getSrcMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getSrcMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getSrcMapAttrStrName()));
   }
 
   /// Returns the source memref affine map indices for this DMA operation.
@@ -156,7 +157,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the destination memref.
   AffineMap getDstMap() { return getDstMapAttr().getValue(); }
   AffineMapAttr getDstMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getDstMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getDstMapAttrStrName()));
   }
 
   /// Returns the destination memref indices for this DMA operation.
@@ -185,7 +187,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the tag memref.
   AffineMap getTagMap() { return getTagMapAttr().getValue(); }
   AffineMapAttr getTagMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getTagMapAttrStrName()));
   }
 
   /// Returns the tag memref indices for this DMA operation.
@@ -307,7 +310,8 @@ class AffineDmaWaitOp
   /// Returns the affine map used to access the tag memref.
   AffineMap getTagMap() { return getTagMapAttr().getValue(); }
   AffineMapAttr getTagMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getTagMapAttrStrName()));
   }
 
   /// Returns the tag memref index for this DMA operation.
@@ -465,6 +469,23 @@ AffineForOp getForInductionVarOwner(Value val);
 /// AffineParallelOp.
 AffineParallelOp getAffineParallelInductionVarOwner(Value val);
 
+/// Helper to replace uses of loop carried values (iter_args) and loop
+/// yield values while promoting single iteration affine.for ops.
+void replaceIterArgsAndYieldResults(AffineForOp forOp);
+
+/// Returns the trip count of the loop as an affine expression if the latter is
+/// expressible as an affine expression, and nullptr otherwise. The trip count
+/// expression is simplified before returning. This method only utilizes map
+/// composition to construct lower and upper bounds before computing the trip
+/// count expressions.
+void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *tripCountMap,
+                                SmallVectorImpl<Value> *tripCountOperands);
+
+/// Returns the trip count of the loop if it's a constant, std::nullopt
+/// otherwise. This uses affine expression analysis and is able to determine
+/// constant trip count in non-trivial cases.
+std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
+
 /// Extracts the induction variables from a list of AffineForOps and places them
 /// in the output argument `ivs`.
 void extractForInductionVars(ArrayRef<AffineForOp> forInsts,
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index f9578cf37d5d768..b4ea6122ed4c0e0 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -121,7 +121,7 @@ def AffineForOp : Affine_Op<"for",
      ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
      ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-      "getSingleUpperBound", "getYieldedValuesMutable",
+      "getSingleUpperBound", "getYieldedValuesMutable", "promoteIfSingleIteration",
       "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {
diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 723a262f24acc51..1e3b3bffea7b838 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -83,10 +83,6 @@ LogicalResult loopUnrollJamByFactor(AffineForOp forOp,
 LogicalResult loopUnrollJamUpToFactor(AffineForOp forOp,
                                       uint64_t unrollJamFactor);
 
-/// Promotes the loop body of a AffineForOp to its containing block if the loop
-/// was known to have a single iteration.
-LogicalResult promoteIfSingleIteration(AffineForOp forOp);
-
 /// Promotes all single iteration AffineForOp's in the Function, i.e., moves
 /// their body into the containing Block.
 void promoteSingleIterationLoops(func::FuncOp f);
diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
index e645afe7cd3e8fa..24f119464b416a7 100644
--- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
@@ -12,7 +12,6 @@
 
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 
-#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
@@ -20,9 +19,9 @@
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Support/MathExtras.h"
 
-#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallString.h"
+
 #include <numeric>
 #include <optional>
 #include <type_traits>
@@ -30,83 +29,6 @@
 using namespace mlir;
 using namespace mlir::affine;
 
-/// Returns the trip count of the loop as an affine expression if the latter is
-/// expressible as an affine expression, and nullptr otherwise. The trip count
-/// expression is simplified before returning. This method only utilizes map
-/// composition to construct lower and upper bounds before computing the trip
-/// count expressions.
-void mlir::affine::getTripCountMapAndOperands(
-    AffineForOp forOp, AffineMap *tripCountMap,
-    SmallVectorImpl<Value> *tripCountOperands) {
-  MLIRContext *context = forOp.getContext();
-  int64_t step = forOp.getStepAsInt();
-  int64_t loopSpan;
-  if (forOp.hasConstantBounds()) {
-    int64_t lb = forOp.getConstantLowerBound();
-    int64_t ub = forOp.getConstantUpperBound();
-    loopSpan = ub - lb;
-    if (loopSpan < 0)
-      loopSpan = 0;
-    *tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
-    tripCountOperands->clear();
-    return;
-  }
-  auto lbMap = forOp.getLowerBoundMap();
-  auto ubMap = forOp.getUpperBoundMap();
-  if (lbMap.getNumResults() != 1) {
-    *tripCountMap = AffineMap();
-    return;
-  }
-
-  // Difference of each upper bound expression from the single lower bound
-  // expression (divided by the step) provides the expressions for the trip
-  // count map.
-  AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());
-
-  SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
-                                         lbMap.getResult(0));
-  auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
-                                   lbSplatExpr, context);
-  AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
-
-  AffineValueMap tripCountValueMap;
-  AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
-  for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
-    tripCountValueMap.setResult(i,
-                                tripCountValueMap.getResult(i).ceilDiv(step));
-
-  *tripCountMap = tripCountValueMap.getAffineMap();
-  tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
-                            tripCountValueMap.getOperands().end());
-}
-
-/// Returns the trip count of the loop if it's a constant, std::nullopt
-/// otherwise. This method uses affine expression analysis (in turn using
-/// getTripCount) and is able to determine constant trip count in non-trivial
-/// cases.
-std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
-  SmallVector<Value, 4> operands;
-  AffineMap map;
-  getTripCountMapAndOperands(forOp, &map, &operands);
-
-  if (!map)
-    return std::nullopt;
-
-  // Take the min if all trip counts are constant.
-  std::optional<uint64_t> tripCount;
-  for (auto resultExpr : map.getResults()) {
-    if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
-      if (tripCount.has_value())
-        tripCount =
-            std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
-      else
-        tripCount = constExpr.getValue();
-    } else
-      return std::nullopt;
-  }
-  return tripCount;
-}
-
 /// Returns the greatest known integral divisor of the trip count. Affine
 /// expression analysis is used (indirectly through getTripCount), and
 /// this method is thus able to determine non-trivial divisors.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 05496e70716a2a1..8716d7a3525b526 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AffineExprVisitor.h"
@@ -23,6 +24,7 @@
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+
 #include <numeric>
 #include <optional>
 
@@ -2440,6 +2442,69 @@ std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
   return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
 }
 
+void mlir::affine::replaceIterArgsAndYieldResults(AffineForOp forOp) {
+  // Replace uses of iter arguments with iter operands (initial values).
+  OperandRange iterOperands = forOp.getInits();
+  MutableArrayRef<BlockArgument> iterArgs = forOp.getRegionIterArgs();
+  for (auto [operand, arg] : llvm::zip(iterOperands, iterArgs))
+    arg.replaceAllUsesWith(operand);
+
+  // Replace uses of loop results with the values yielded by the loop.
+  ResultRange outerResults = forOp.getResults();
+  OperandRange innerResults = forOp.getBody()->getTerminator()->getOperands();
+  for (auto [outer, inner] : llvm::zip(outerResults, innerResults))
+    outer.replaceAllUsesWith(inner);
+}
+
+LogicalResult AffineForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
+  auto forOp = cast<AffineForOp>(getOperation());
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount || *tripCount != 1)
+    return failure();
+
+  // TODO: extend this for arbitrary affine bounds.
+  if (forOp.getLowerBoundMap().getNumResults() != 1)
+    return failure();
+
+  // Replaces all IV uses to its single iteration value.
+  BlockArgument iv = forOp.getInductionVar();
+  if (!iv.use_empty()) {
+    if (forOp.hasConstantLowerBound()) {
+      Operation *parentOp = forOp.getOperation();
+      while (isa<AffineForOp>(parentOp->getParentOp()))
+        parentOp = parentOp->getParentOp();
+      Block *parentBlock = parentOp->getBlock();
+      OpBuilder topBuilder(parentBlock, parentBlock->begin());
+      auto constOp = topBuilder.create<arith::ConstantIndexOp>(
+          forOp.getLoc(), forOp.getConstantLowerBound());
+      iv.replaceAllUsesWith(constOp);
+    } else {
+      OperandRange lbOperands = forOp.getLowerBoundOperands();
+      AffineMap lbMap = forOp.getLowerBoundMap();
+      OpBuilder builder(forOp);
+      if (lbMap == builder.getDimIdentityMap()) {
+        // No need of generating an affine.apply.
+        iv.replaceAllUsesWith(lbOperands[0]);
+      } else {
+        auto affineApplyOp =
+            builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
+        iv.replaceAllUsesWith(affineApplyOp);
+      }
+    }
+  }
+
+  replaceIterArgsAndYieldResults(forOp);
+
+  // Move the loop body operations, except for its terminator, to the loop's
+  // containing block.
+  forOp.getBody()->back().erase();
+  Block *parentBlock = forOp->getBlock();
+  parentBlock->getOperations().splice(Block::iterator(forOp),
+                                      forOp.getBody()->getOperations());
+  forOp.erase();
+  return success();
+}
+
 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
     RewriterBase &rewriter, ValueRange newInitOperands,
     bool replaceInitOperandUsesInLoop,
@@ -2538,6 +2603,79 @@ AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
   return nullptr;
 }
 
+/// Returns the trip count of the loop as an affine expression if the latter is
+/// expressible as an affine expression, and nullptr otherwise. The trip count
+/// expression is simplified before returning. This method only utilizes map
+/// composition to construct lower and upper bounds before computing the trip
+/// count expressions.
+void mlir::affine::getTripCountMapAndOperands(
+    AffineForOp forOp, AffineMap *tripCountMap,
+    SmallVectorImpl<Value> *tripCountOperands) {
+  MLIRContext *context = forOp.getContext();
+  int64_t step = forOp.getStepAsInt();
+  int64_t loopSpan;
+  if (forOp.hasConstantBounds()) {
+    int64_t lb = forOp.getConstantLowerBound();
+    int64_t ub = forOp.getConstantUpperBound();
+    loopSpan = ub - lb;
+    if (loopSpan < 0)
+      loopSpan = 0;
+    *tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
+    tripCountOperands->clear();
+    return;
+  }
+  auto lbMap = forOp.getLowerBoundMap();
+  auto ubMap = forOp.getUpperBoundMap();
+  if (lbMap.getNumResults() != 1) {
+    *tripCountMap = AffineMap();
+    return;
+  }
+
+  // Difference of each upper bound expression from the single lower bound
+  // expression (divided by the step) provides the expressions for the trip
+  // count map.
+  AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());
+
+  SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
+                                         lbMap.getResult(0));
+  auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
+                                   lbSplatExpr, context);
+  AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
+
+  AffineValueMap tripCountValueMap;
+  AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
+  for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
+    tripCountValueMap.setResult(i,
+                                tripCountValueMap.getResult(i).ceilDiv(step));
+
+  *tripCountMap = tripCountValueMap.getAffineMap();
+  tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
+                            tripCountValueMap.getOperands().end());
+}
+
+std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
+  SmallVector<Value, 4> operands;
+  AffineMap map;
+  getTripCountMapAndOperands(forOp, &map, &operands);
+
+  if (!map)
+    return std::nullopt;
+
+  // Take the min if all trip counts are constant.
+  std::optional<uint64_t> tripCount;
+  for (auto resultExpr : map.getResults()) {
+    if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
+      if (tripCount.has_value())
+        tripCount =
+            std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
+      else
+        tripCount = constExpr.getValue();
+    } else
+      return std::nullopt;
+  }
+  return tripCount;
+}
+
 /// Extracts the induction variables from a list of AffineForOps and returns
 /// them.
 void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
@@ -2905,8 +3043,7 @@ static void composeSetAndOperands(IntegerSet &set,
 }
 
 /// Canonicalize an affine if op's conditional (integer set + operands).
-LogicalResult AffineIfOp::fold(FoldAdaptor,
-                               SmallVectorImpl<OpFoldResult> &) {
+LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   auto set = getIntegerSet();
   SmallVector<Value, 4> operands(getOperands());
   composeSetAndOperands(set, operands);
@@ -2997,11 +3134,11 @@ static LogicalResult
 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
                        Operation::operand_range mapOperands,
                        MemRefType memrefType, unsigned numIndexOperands) {
-    AffineMap map = mapAttr.getValue();
-    if (map.getNumResults() != memrefType.getRank())
-      return op->emitOpError("affine map num results must equal memref rank");
-    if (map.getNumInputs() != numIndexOperands)
-      return op->emitOpError("expects as many subscripts as affine map inputs");
+  AffineMap map = mapAttr.getValue();
+  if (map.getNumResults() != memrefType.getRank())
+    return op->emitOpError("affine map num results must equal memref rank");
+  if (map.getNumInputs() != numIndexOperands)
+    return op->emitOpError("expects as many subscripts as affine map inputs");
 
   Region *scope = getAffineScope(op);
   for (auto idx : mapOperands) {
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 331b0f1b2c2b1c6..31b90a60472c1f1 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -219,13 +219,14 @@ void AffineDataCopyGeneration::runOnOperation() {
 
   // Promote any single iteration loops in the copy nests and collect
   // load/stores to simplify.
+  IRRewriter rewriter(f.getContext());
   SmallVector<Operation *, 4> copyOps;
   for (Operation *nest : copyNests)
     // With a post order walk, the erasure of loops does not affect
     // continuation of the walk or the collection of load/store ops.
     nest->walk([&](Operation *op) {
       if (auto forOp = dyn_cast<AffineForOp>(op))
-        (void)promoteIfSingleIteration(forOp);
+        (void)forOp.promoteIfSingleIteration(rewriter);
       else if (isa<AffineLoadOp, AffineStoreOp>(op))
         copyOps.push_back(op);
     });
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 5053b08ee0834cd..d11e77544e24ea5 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -457,6 +457,7 @@ void m...
[truncated]

BlockArgument iv = forOp.getInductionVar();
if (!iv.use_empty()) {
if (forOp.hasConstantLowerBound()) {
Operation *parentOp = forOp.getOperation();
Copy link
Contributor

@srcarroll srcarroll Nov 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't you just do rewriter.setInsertPoint(forOp) and build the constant op with rewriter? this will write the ops right before the forOp. You can also use an insertion guard so that the rewriter insertion point is restored after leaving this scope

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nevermind. i see what you are doing below now and want the first non-for parent. but my reuse rewriter comment should still be applied

} else {
OperandRange lbOperands = forOp.getLowerBoundOperands();
AffineMap lbMap = forOp.getLowerBoundMap();
OpBuilder builder(forOp);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar to comment above; reuse rewriter and set insertion point

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in fact i now realize the rewriter arg wasn't used in this function at all, so general comment is to use this rather than creating new builders, unless there's a reason i dont see to do it this way

Copy link
Contributor

@srcarroll srcarroll left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just commenting on promoteIfSingleIteration for now, but will leave other comments later

} else {
OperandRange lbOperands = forOp.getLowerBoundOperands();
AffineMap lbMap = forOp.getLowerBoundMap();
OpBuilder builder(forOp);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in fact i now realize the rewriter arg wasn't used in this function at all, so general comment is to use this rather than creating new builders, unless there's a reason i dont see to do it this way

@srcarroll
Copy link
Contributor

srcarroll commented Nov 19, 2023

i'm no authority here but i feel like just leaving the constant in the parent of the affine.for that's being removed is the correct (semantically equivalent) way to do this. after all, canonicalize would lift the constant where it should go as dictated by all the parents' canonicalizers.

But anyway, I agree with your change more than leaving it with the func, but could be simpler is all I'm suggesting.

@srcarroll
Copy link
Contributor

srcarroll commented Nov 19, 2023

Another point against leaving it as inserting in the parent func.func (or at least as it is now pre-change) is that it's too restrictive and in fact buggy. first of all this wouldn't work on affine.for inside any other funclike op. second, the original implementation doesn't have any null checks so would crash anyway.

/// otherwise. This uses affine expression analysis and is able to determine
/// constant trip count in non-trivial cases.
std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

although it would be orthogonal to this PR, it may be worth thinking about moving these to AffineForOp methods since they are being moved to the dialect anyway. Just a thought, no action necessary.

Block *parentBlock = forOp->getBlock();
parentBlock->getOperations().splice(Block::iterator(forOp),
forOp.getBody()->getOperations());
forOp.erase();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it's safer to use rewriter.eraseOp(forOp) and similar on line 2500 above. This could break if used in a pattern rewriter for a class. It has happened to me before.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants