Skip to content

Commit aa6be2f

Browse files
authored
[mlir][affine] implement promoteIfSingleIteration for AffineForOp (#72547)
1 parent 99387e3 commit aa6be2f

File tree

11 files changed

+190
-185
lines changed

11 files changed

+190
-185
lines changed

mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/Support/LLVM.h"
1717
#include "llvm/ADT/ArrayRef.h"
18+
1819
#include <optional>
1920

2021
namespace mlir {
@@ -29,20 +30,6 @@ namespace affine {
2930
class AffineForOp;
3031
class NestedPattern;
3132

32-
/// Returns the trip count of the loop as an affine map with its corresponding
33-
/// operands if the latter is expressible as an affine expression, and nullptr
34-
/// otherwise. This method always succeeds as long as the lower bound is not a
35-
/// multi-result map. The trip count expression is simplified before returning.
36-
/// This method only utilizes map composition to construct lower and upper
37-
/// bounds before computing the trip count expressions
38-
void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
39-
SmallVectorImpl<Value> *operands);
40-
41-
/// Returns the trip count of the loop if it's a constant, std::nullopt
42-
/// otherwise. This uses affine expression analysis and is able to determine
43-
/// constant trip count in non-trivial cases.
44-
std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
45-
4633
/// Returns the greatest known integral divisor of the trip count. Affine
4734
/// expression analysis is used (indirectly through getTripCount), and
4835
/// this method is thus able to determine non-trivial divisors.

mlir/include/mlir/Dialect/Affine/IR/AffineOps.h

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ class AffineDmaStartOp
117117
/// Returns the affine map used to access the source memref.
118118
AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
119119
AffineMapAttr getSrcMapAttr() {
120-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getSrcMapAttrStrName()));
120+
return cast<AffineMapAttr>(
121+
*(*this)->getInherentAttr(getSrcMapAttrStrName()));
121122
}
122123

123124
/// Returns the source memref affine map indices for this DMA operation.
@@ -156,7 +157,8 @@ class AffineDmaStartOp
156157
/// Returns the affine map used to access the destination memref.
157158
AffineMap getDstMap() { return getDstMapAttr().getValue(); }
158159
AffineMapAttr getDstMapAttr() {
159-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getDstMapAttrStrName()));
160+
return cast<AffineMapAttr>(
161+
*(*this)->getInherentAttr(getDstMapAttrStrName()));
160162
}
161163

162164
/// Returns the destination memref indices for this DMA operation.
@@ -185,7 +187,8 @@ class AffineDmaStartOp
185187
/// Returns the affine map used to access the tag memref.
186188
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
187189
AffineMapAttr getTagMapAttr() {
188-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
190+
return cast<AffineMapAttr>(
191+
*(*this)->getInherentAttr(getTagMapAttrStrName()));
189192
}
190193

191194
/// Returns the tag memref indices for this DMA operation.
@@ -307,7 +310,8 @@ class AffineDmaWaitOp
307310
/// Returns the affine map used to access the tag memref.
308311
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
309312
AffineMapAttr getTagMapAttr() {
310-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
313+
return cast<AffineMapAttr>(
314+
*(*this)->getInherentAttr(getTagMapAttrStrName()));
311315
}
312316

313317
/// Returns the tag memref index for this DMA operation.
@@ -465,6 +469,23 @@ AffineForOp getForInductionVarOwner(Value val);
465469
/// AffineParallelOp.
466470
AffineParallelOp getAffineParallelInductionVarOwner(Value val);
467471

472+
/// Helper to replace uses of loop carried values (iter_args) and loop
473+
/// yield values while promoting single iteration affine.for ops.
474+
void replaceIterArgsAndYieldResults(AffineForOp forOp);
475+
476+
/// Returns the trip count of the loop as an affine expression if the latter is
477+
/// expressible as an affine expression, and nullptr otherwise. The trip count
478+
/// expression is simplified before returning. This method only utilizes map
479+
/// composition to construct lower and upper bounds before computing the trip
480+
/// count expressions.
481+
void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *tripCountMap,
482+
SmallVectorImpl<Value> *tripCountOperands);
483+
484+
/// Returns the trip count of the loop if it's a constant, std::nullopt
485+
/// otherwise. This uses affine expression analysis and is able to determine
486+
/// constant trip count in non-trivial cases.
487+
std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
488+
468489
/// Extracts the induction variables from a list of AffineForOps and places them
469490
/// in the output argument `ivs`.
470491
void extractForInductionVars(ArrayRef<AffineForOp> forInsts,

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def AffineForOp : Affine_Op<"for",
121121
ImplicitAffineTerminator, ConditionallySpeculatable,
122122
RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
123123
["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
124-
"getSingleUpperBound", "getYieldedValuesMutable",
124+
"getSingleUpperBound", "getYieldedValuesMutable", "promoteIfSingleIteration",
125125
"replaceWithAdditionalYields"]>,
126126
DeclareOpInterfaceMethods<RegionBranchOpInterface,
127127
["getEntrySuccessorOperands"]>]> {

mlir/include/mlir/Dialect/Affine/LoopUtils.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,6 @@ LogicalResult loopUnrollJamByFactor(AffineForOp forOp,
8383
LogicalResult loopUnrollJamUpToFactor(AffineForOp forOp,
8484
uint64_t unrollJamFactor);
8585

86-
/// Promotes the loop body of a AffineForOp to its containing block if the loop
87-
/// was known to have a single iteration.
88-
LogicalResult promoteIfSingleIteration(AffineForOp forOp);
89-
9086
/// Promotes all single iteration AffineForOp's in the Function, i.e., moves
9187
/// their body into the containing Block.
9288
void promoteSingleIterationLoops(func::FuncOp f);

mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp

Lines changed: 1 addition & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -12,101 +12,23 @@
1212

1313
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1414

15-
#include "mlir/Analysis/SliceAnalysis.h"
1615
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
1716
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
1817
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
1918
#include "mlir/Dialect/Affine/IR/AffineOps.h"
2019
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
2120
#include "mlir/Support/MathExtras.h"
2221

23-
#include "llvm/ADT/DenseSet.h"
2422
#include "llvm/ADT/SmallPtrSet.h"
2523
#include "llvm/ADT/SmallString.h"
24+
2625
#include <numeric>
2726
#include <optional>
2827
#include <type_traits>
2928

3029
using namespace mlir;
3130
using namespace mlir::affine;
3231

33-
/// Returns the trip count of the loop as an affine expression if the latter is
34-
/// expressible as an affine expression, and nullptr otherwise. The trip count
35-
/// expression is simplified before returning. This method only utilizes map
36-
/// composition to construct lower and upper bounds before computing the trip
37-
/// count expressions.
38-
void mlir::affine::getTripCountMapAndOperands(
39-
AffineForOp forOp, AffineMap *tripCountMap,
40-
SmallVectorImpl<Value> *tripCountOperands) {
41-
MLIRContext *context = forOp.getContext();
42-
int64_t step = forOp.getStepAsInt();
43-
int64_t loopSpan;
44-
if (forOp.hasConstantBounds()) {
45-
int64_t lb = forOp.getConstantLowerBound();
46-
int64_t ub = forOp.getConstantUpperBound();
47-
loopSpan = ub - lb;
48-
if (loopSpan < 0)
49-
loopSpan = 0;
50-
*tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
51-
tripCountOperands->clear();
52-
return;
53-
}
54-
auto lbMap = forOp.getLowerBoundMap();
55-
auto ubMap = forOp.getUpperBoundMap();
56-
if (lbMap.getNumResults() != 1) {
57-
*tripCountMap = AffineMap();
58-
return;
59-
}
60-
61-
// Difference of each upper bound expression from the single lower bound
62-
// expression (divided by the step) provides the expressions for the trip
63-
// count map.
64-
AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());
65-
66-
SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
67-
lbMap.getResult(0));
68-
auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
69-
lbSplatExpr, context);
70-
AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
71-
72-
AffineValueMap tripCountValueMap;
73-
AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
74-
for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
75-
tripCountValueMap.setResult(i,
76-
tripCountValueMap.getResult(i).ceilDiv(step));
77-
78-
*tripCountMap = tripCountValueMap.getAffineMap();
79-
tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
80-
tripCountValueMap.getOperands().end());
81-
}
82-
83-
/// Returns the trip count of the loop if it's a constant, std::nullopt
84-
/// otherwise. This method uses affine expression analysis (in turn using
85-
/// getTripCount) and is able to determine constant trip count in non-trivial
86-
/// cases.
87-
std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
88-
SmallVector<Value, 4> operands;
89-
AffineMap map;
90-
getTripCountMapAndOperands(forOp, &map, &operands);
91-
92-
if (!map)
93-
return std::nullopt;
94-
95-
// Take the min if all trip counts are constant.
96-
std::optional<uint64_t> tripCount;
97-
for (auto resultExpr : map.getResults()) {
98-
if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
99-
if (tripCount.has_value())
100-
tripCount =
101-
std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
102-
else
103-
tripCount = constExpr.getValue();
104-
} else
105-
return std::nullopt;
106-
}
107-
return tripCount;
108-
}
109-
11032
/// Returns the greatest known integral divisor of the trip count. Affine
11133
/// expression analysis is used (indirectly through getTripCount), and
11234
/// this method is thus able to determine non-trivial divisors.

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 140 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Affine/IR/AffineOps.h"
10+
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1011
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
1113
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1214
#include "mlir/Dialect/UB/IR/UBOps.h"
1315
#include "mlir/IR/AffineExprVisitor.h"
@@ -2448,6 +2450,65 @@ std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
24482450
return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
24492451
}
24502452

2453+
void mlir::affine::replaceIterArgsAndYieldResults(AffineForOp forOp) {
2454+
// Replace uses of iter arguments with iter operands (initial values).
2455+
OperandRange iterOperands = forOp.getInits();
2456+
MutableArrayRef<BlockArgument> iterArgs = forOp.getRegionIterArgs();
2457+
for (auto [operand, arg] : llvm::zip(iterOperands, iterArgs))
2458+
arg.replaceAllUsesWith(operand);
2459+
2460+
// Replace uses of loop results with the values yielded by the loop.
2461+
ResultRange outerResults = forOp.getResults();
2462+
OperandRange innerResults = forOp.getBody()->getTerminator()->getOperands();
2463+
for (auto [outer, inner] : llvm::zip(outerResults, innerResults))
2464+
outer.replaceAllUsesWith(inner);
2465+
}
2466+
2467+
LogicalResult AffineForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
2468+
auto forOp = cast<AffineForOp>(getOperation());
2469+
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
2470+
if (!tripCount || *tripCount != 1)
2471+
return failure();
2472+
2473+
// TODO: extend this for arbitrary affine bounds.
2474+
if (forOp.getLowerBoundMap().getNumResults() != 1)
2475+
return failure();
2476+
2477+
// Replaces all IV uses to its single iteration value.
2478+
BlockArgument iv = forOp.getInductionVar();
2479+
Block *parentBlock = forOp->getBlock();
2480+
if (!iv.use_empty()) {
2481+
if (forOp.hasConstantLowerBound()) {
2482+
OpBuilder topBuilder(forOp->getParentOfType<func::FuncOp>().getBody());
2483+
auto constOp = topBuilder.create<arith::ConstantIndexOp>(
2484+
forOp.getLoc(), forOp.getConstantLowerBound());
2485+
iv.replaceAllUsesWith(constOp);
2486+
} else {
2487+
OperandRange lbOperands = forOp.getLowerBoundOperands();
2488+
AffineMap lbMap = forOp.getLowerBoundMap();
2489+
OpBuilder builder(forOp);
2490+
if (lbMap == builder.getDimIdentityMap()) {
2491+
// No need of generating an affine.apply.
2492+
iv.replaceAllUsesWith(lbOperands[0]);
2493+
} else {
2494+
auto affineApplyOp =
2495+
builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
2496+
iv.replaceAllUsesWith(affineApplyOp);
2497+
}
2498+
}
2499+
}
2500+
2501+
replaceIterArgsAndYieldResults(forOp);
2502+
2503+
// Move the loop body operations, except for its terminator, to the loop's
2504+
// containing block.
2505+
forOp.getBody()->back().erase();
2506+
parentBlock->getOperations().splice(Block::iterator(forOp),
2507+
forOp.getBody()->getOperations());
2508+
forOp.erase();
2509+
return success();
2510+
}
2511+
24512512
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
24522513
RewriterBase &rewriter, ValueRange newInitOperands,
24532514
bool replaceInitOperandUsesInLoop,
@@ -2546,6 +2607,79 @@ AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
25462607
return nullptr;
25472608
}
25482609

2610+
/// Returns the trip count of the loop as an affine expression if the latter is
2611+
/// expressible as an affine expression, and nullptr otherwise. The trip count
2612+
/// expression is simplified before returning. This method only utilizes map
2613+
/// composition to construct lower and upper bounds before computing the trip
2614+
/// count expressions.
2615+
void mlir::affine::getTripCountMapAndOperands(
2616+
AffineForOp forOp, AffineMap *tripCountMap,
2617+
SmallVectorImpl<Value> *tripCountOperands) {
2618+
MLIRContext *context = forOp.getContext();
2619+
int64_t step = forOp.getStepAsInt();
2620+
int64_t loopSpan;
2621+
if (forOp.hasConstantBounds()) {
2622+
int64_t lb = forOp.getConstantLowerBound();
2623+
int64_t ub = forOp.getConstantUpperBound();
2624+
loopSpan = ub - lb;
2625+
if (loopSpan < 0)
2626+
loopSpan = 0;
2627+
*tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
2628+
tripCountOperands->clear();
2629+
return;
2630+
}
2631+
auto lbMap = forOp.getLowerBoundMap();
2632+
auto ubMap = forOp.getUpperBoundMap();
2633+
if (lbMap.getNumResults() != 1) {
2634+
*tripCountMap = AffineMap();
2635+
return;
2636+
}
2637+
2638+
// Difference of each upper bound expression from the single lower bound
2639+
// expression (divided by the step) provides the expressions for the trip
2640+
// count map.
2641+
AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());
2642+
2643+
SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
2644+
lbMap.getResult(0));
2645+
auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
2646+
lbSplatExpr, context);
2647+
AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
2648+
2649+
AffineValueMap tripCountValueMap;
2650+
AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
2651+
for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
2652+
tripCountValueMap.setResult(i,
2653+
tripCountValueMap.getResult(i).ceilDiv(step));
2654+
2655+
*tripCountMap = tripCountValueMap.getAffineMap();
2656+
tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
2657+
tripCountValueMap.getOperands().end());
2658+
}
2659+
2660+
std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
2661+
SmallVector<Value, 4> operands;
2662+
AffineMap map;
2663+
getTripCountMapAndOperands(forOp, &map, &operands);
2664+
2665+
if (!map)
2666+
return std::nullopt;
2667+
2668+
// Take the min if all trip counts are constant.
2669+
std::optional<uint64_t> tripCount;
2670+
for (auto resultExpr : map.getResults()) {
2671+
if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
2672+
if (tripCount.has_value())
2673+
tripCount =
2674+
std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
2675+
else
2676+
tripCount = constExpr.getValue();
2677+
} else
2678+
return std::nullopt;
2679+
}
2680+
return tripCount;
2681+
}
2682+
25492683
/// Extracts the induction variables from a list of AffineForOps and returns
25502684
/// them.
25512685
void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
@@ -2913,8 +3047,7 @@ static void composeSetAndOperands(IntegerSet &set,
29133047
}
29143048

29153049
/// Canonicalize an affine if op's conditional (integer set + operands).
2916-
LogicalResult AffineIfOp::fold(FoldAdaptor,
2917-
SmallVectorImpl<OpFoldResult> &) {
3050+
LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
29183051
auto set = getIntegerSet();
29193052
SmallVector<Value, 4> operands(getOperands());
29203053
composeSetAndOperands(set, operands);
@@ -3005,11 +3138,11 @@ static LogicalResult
30053138
verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
30063139
Operation::operand_range mapOperands,
30073140
MemRefType memrefType, unsigned numIndexOperands) {
3008-
AffineMap map = mapAttr.getValue();
3009-
if (map.getNumResults() != memrefType.getRank())
3010-
return op->emitOpError("affine map num results must equal memref rank");
3011-
if (map.getNumInputs() != numIndexOperands)
3012-
return op->emitOpError("expects as many subscripts as affine map inputs");
3141+
AffineMap map = mapAttr.getValue();
3142+
if (map.getNumResults() != memrefType.getRank())
3143+
return op->emitOpError("affine map num results must equal memref rank");
3144+
if (map.getNumInputs() != numIndexOperands)
3145+
return op->emitOpError("expects as many subscripts as affine map inputs");
30133146

30143147
Region *scope = getAffineScope(op);
30153148
for (auto idx : mapOperands) {

0 commit comments

Comments
 (0)