Skip to content

Commit f8e5e5b

Browse files
committed
[mlir][affine] affineforop promote single
1 parent 3f743fd commit f8e5e5b

File tree

10 files changed

+96
-88
lines changed

10 files changed

+96
-88
lines changed

mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
4343
/// constant trip count in non-trivial cases.
4444
std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
4545

46+
/// Helper to replace uses of loop carried values (iter_args) and loop
47+
/// yield values while promoting single iteration affine.for ops.
48+
void replaceIterArgsAndYieldResults(AffineForOp forOp);
49+
4650
/// Returns the greatest known integral divisor of the trip count. Affine
4751
/// expression analysis is used (indirectly through getTripCount), and
4852
/// this method is thus able to determine non-trivial divisors.

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def AffineForOp : Affine_Op<"for",
121121
ImplicitAffineTerminator, ConditionallySpeculatable,
122122
RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
123123
["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
124-
"getSingleUpperBound", "getYieldedValuesMutable",
124+
"getSingleUpperBound", "getYieldedValuesMutable", "promoteIfSingleIteration",
125125
"replaceWithAdditionalYields"]>,
126126
DeclareOpInterfaceMethods<RegionBranchOpInterface,
127127
["getEntrySuccessorOperands"]>]> {

mlir/include/mlir/Dialect/Affine/LoopUtils.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,6 @@ LogicalResult loopUnrollJamByFactor(AffineForOp forOp,
8383
LogicalResult loopUnrollJamUpToFactor(AffineForOp forOp,
8484
uint64_t unrollJamFactor);
8585

86-
/// Promotes the loop body of a AffineForOp to its containing block if the loop
87-
/// was known to have a single iteration.
88-
LogicalResult promoteIfSingleIteration(AffineForOp forOp);
89-
9086
/// Promotes all single iteration AffineForOp's in the Function, i.e., moves
9187
/// their body into the containing Block.
9288
void promoteSingleIterationLoops(func::FuncOp f);

mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,20 @@ std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
107107
return tripCount;
108108
}
109109

110+
void mlir::affine::replaceIterArgsAndYieldResults(AffineForOp forOp) {
111+
// Replace uses of iter arguments with iter operands (initial values).
112+
auto iterOperands = forOp.getInits();
113+
auto iterArgs = forOp.getRegionIterArgs();
114+
for (auto e : llvm::zip(iterOperands, iterArgs))
115+
std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
116+
117+
// Replace uses of loop results with the values yielded by the loop.
118+
auto outerResults = forOp.getResults();
119+
auto innerResults = forOp.getBody()->getTerminator()->getOperands();
120+
for (auto e : llvm::zip(outerResults, innerResults))
121+
std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
122+
}
123+
110124
/// Returns the greatest known integral divisor of the trip count. Affine
111125
/// expression analysis is used (indirectly through getTripCount), and
112126
/// this method is thus able to determine non-trivial divisors.

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Affine/IR/AffineOps.h"
10+
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1011
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
1113
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1214
#include "mlir/IR/AffineExprVisitor.h"
1315
#include "mlir/IR/IRMapping.h"
@@ -2440,6 +2442,53 @@ std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
24402442
return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
24412443
}
24422444

2445+
/// Promotes the loop body of a forOp to its containing block if the forOp
2446+
/// was known to have a single iteration.
2447+
LogicalResult AffineForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
2448+
auto forOp = cast<AffineForOp>(getOperation());
2449+
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
2450+
if (!tripCount || *tripCount != 1)
2451+
return failure();
2452+
2453+
// TODO: extend this for arbitrary affine bounds.
2454+
if (forOp.getLowerBoundMap().getNumResults() != 1)
2455+
return failure();
2456+
2457+
// Replaces all IV uses to its single iteration value.
2458+
auto iv = forOp.getInductionVar();
2459+
auto *parentBlock = forOp->getBlock();
2460+
if (!iv.use_empty()) {
2461+
if (forOp.hasConstantLowerBound()) {
2462+
OpBuilder topBuilder(forOp->getParentOfType<func::FuncOp>().getBody());
2463+
auto constOp = topBuilder.create<arith::ConstantIndexOp>(
2464+
forOp.getLoc(), forOp.getConstantLowerBound());
2465+
iv.replaceAllUsesWith(constOp);
2466+
} else {
2467+
auto lbOperands = forOp.getLowerBoundOperands();
2468+
auto lbMap = forOp.getLowerBoundMap();
2469+
OpBuilder builder(forOp);
2470+
if (lbMap == builder.getDimIdentityMap()) {
2471+
// No need of generating an affine.apply.
2472+
iv.replaceAllUsesWith(lbOperands[0]);
2473+
} else {
2474+
auto affineApplyOp =
2475+
builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
2476+
iv.replaceAllUsesWith(affineApplyOp);
2477+
}
2478+
}
2479+
}
2480+
2481+
replaceIterArgsAndYieldResults(forOp);
2482+
2483+
// Move the loop body operations, except for its terminator, to the loop's
2484+
// containing block.
2485+
forOp.getBody()->back().erase();
2486+
parentBlock->getOperations().splice(Block::iterator(forOp),
2487+
forOp.getBody()->getOperations());
2488+
forOp.erase();
2489+
return success();
2490+
}
2491+
24432492
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
24442493
RewriterBase &rewriter, ValueRange newInitOperands,
24452494
bool replaceInitOperandUsesInLoop,
@@ -2905,8 +2954,7 @@ static void composeSetAndOperands(IntegerSet &set,
29052954
}
29062955

29072956
/// Canonicalize an affine if op's conditional (integer set + operands).
2908-
LogicalResult AffineIfOp::fold(FoldAdaptor,
2909-
SmallVectorImpl<OpFoldResult> &) {
2957+
LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
29102958
auto set = getIntegerSet();
29112959
SmallVector<Value, 4> operands(getOperands());
29122960
composeSetAndOperands(set, operands);
@@ -2997,11 +3045,11 @@ static LogicalResult
29973045
verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
29983046
Operation::operand_range mapOperands,
29993047
MemRefType memrefType, unsigned numIndexOperands) {
3000-
AffineMap map = mapAttr.getValue();
3001-
if (map.getNumResults() != memrefType.getRank())
3002-
return op->emitOpError("affine map num results must equal memref rank");
3003-
if (map.getNumInputs() != numIndexOperands)
3004-
return op->emitOpError("expects as many subscripts as affine map inputs");
3048+
AffineMap map = mapAttr.getValue();
3049+
if (map.getNumResults() != memrefType.getRank())
3050+
return op->emitOpError("affine map num results must equal memref rank");
3051+
if (map.getNumInputs() != numIndexOperands)
3052+
return op->emitOpError("expects as many subscripts as affine map inputs");
30053053

30063054
Region *scope = getAffineScope(op);
30073055
for (auto idx : mapOperands) {

mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,13 +219,14 @@ void AffineDataCopyGeneration::runOnOperation() {
219219

220220
// Promote any single iteration loops in the copy nests and collect
221221
// load/stores to simplify.
222+
IRRewriter rewriter(f.getContext());
222223
SmallVector<Operation *, 4> copyOps;
223224
for (Operation *nest : copyNests)
224225
// With a post order walk, the erasure of loops does not affect
225226
// continuation of the walk or the collection of load/store ops.
226227
nest->walk([&](Operation *op) {
227228
if (auto forOp = dyn_cast<AffineForOp>(op))
228-
(void)promoteIfSingleIteration(forOp);
229+
(void)forOp.promoteIfSingleIteration(rewriter);
229230
else if (isa<AffineLoadOp, AffineStoreOp>(op))
230231
copyOps.push_back(op);
231232
});

mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -457,16 +457,16 @@ void mlir::affine::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
457457
return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
458458
(getSliceIterationCount(sliceTripCountMap) == 1));
459459
};
460+
IRRewriter rewriter(srcForOp.getContext());
460461
// Fix up and if possible, eliminate single iteration loops.
461462
for (AffineForOp forOp : sliceLoops) {
462463
if (isLoopParallelAndContainsReduction(forOp) &&
463464
isInnermostSiblingInsertion && srcIsUnitSlice())
464465
// Patch reduction loop - only ones that are sibling-fused with the
465466
// destination loop - into the parent loop.
466467
(void)promoteSingleIterReductionLoop(forOp, true);
467-
else
468-
// Promote any single iteration slice loops.
469-
(void)promoteIfSingleIteration(forOp);
468+
else // Promote any single iteration slice loops.
469+
(void)forOp.promoteIfSingleIteration(rewriter);
470470
}
471471
}
472472

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 13 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -110,68 +110,6 @@ getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
110110
lb.erase();
111111
}
112112

113-
/// Helper to replace uses of loop carried values (iter_args) and loop
114-
/// yield values while promoting single iteration affine.for ops.
115-
static void replaceIterArgsAndYieldResults(AffineForOp forOp) {
116-
// Replace uses of iter arguments with iter operands (initial values).
117-
auto iterOperands = forOp.getInits();
118-
auto iterArgs = forOp.getRegionIterArgs();
119-
for (auto e : llvm::zip(iterOperands, iterArgs))
120-
std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
121-
122-
// Replace uses of loop results with the values yielded by the loop.
123-
auto outerResults = forOp.getResults();
124-
auto innerResults = forOp.getBody()->getTerminator()->getOperands();
125-
for (auto e : llvm::zip(outerResults, innerResults))
126-
std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
127-
}
128-
129-
/// Promotes the loop body of a forOp to its containing block if the forOp
130-
/// was known to have a single iteration.
131-
LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
132-
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
133-
if (!tripCount || *tripCount != 1)
134-
return failure();
135-
136-
// TODO: extend this for arbitrary affine bounds.
137-
if (forOp.getLowerBoundMap().getNumResults() != 1)
138-
return failure();
139-
140-
// Replaces all IV uses to its single iteration value.
141-
auto iv = forOp.getInductionVar();
142-
auto *parentBlock = forOp->getBlock();
143-
if (!iv.use_empty()) {
144-
if (forOp.hasConstantLowerBound()) {
145-
OpBuilder topBuilder(forOp->getParentOfType<func::FuncOp>().getBody());
146-
auto constOp = topBuilder.create<arith::ConstantIndexOp>(
147-
forOp.getLoc(), forOp.getConstantLowerBound());
148-
iv.replaceAllUsesWith(constOp);
149-
} else {
150-
auto lbOperands = forOp.getLowerBoundOperands();
151-
auto lbMap = forOp.getLowerBoundMap();
152-
OpBuilder builder(forOp);
153-
if (lbMap == builder.getDimIdentityMap()) {
154-
// No need of generating an affine.apply.
155-
iv.replaceAllUsesWith(lbOperands[0]);
156-
} else {
157-
auto affineApplyOp =
158-
builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
159-
iv.replaceAllUsesWith(affineApplyOp);
160-
}
161-
}
162-
}
163-
164-
replaceIterArgsAndYieldResults(forOp);
165-
166-
// Move the loop body operations, except for its terminator, to the loop's
167-
// containing block.
168-
forOp.getBody()->back().erase();
169-
parentBlock->getOperations().splice(Block::iterator(forOp),
170-
forOp.getBody()->getOperations());
171-
forOp.erase();
172-
return success();
173-
}
174-
175113
/// Generates an affine.for op with the specified lower and upper bounds
176114
/// while generating the right IV remappings to realize shifts for operations in
177115
/// its body. The operations that go into the loop body are specified in
@@ -218,7 +156,9 @@ static AffineForOp generateShiftedLoop(
218156
for (auto *op : ops)
219157
bodyBuilder.clone(*op, operandMap);
220158
};
221-
if (succeeded(promoteIfSingleIteration(loopChunk)))
159+
160+
IRRewriter rewriter(loopChunk.getContext());
161+
if (succeeded(loopChunk.promoteIfSingleIteration(rewriter)))
222162
return AffineForOp();
223163
return loopChunk;
224164
}
@@ -892,12 +832,13 @@ void mlir::affine::getTileableBands(
892832
/// Unrolls this loop completely.
893833
LogicalResult mlir::affine::loopUnrollFull(AffineForOp forOp) {
894834
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
835+
IRRewriter rewriter(forOp.getContext());
895836
if (mayBeConstantTripCount.has_value()) {
896837
uint64_t tripCount = *mayBeConstantTripCount;
897838
if (tripCount == 0)
898839
return success();
899840
if (tripCount == 1)
900-
return promoteIfSingleIteration(forOp);
841+
return forOp.promoteIfSingleIteration(rewriter);
901842
return loopUnrollByFactor(forOp, tripCount);
902843
}
903844
return failure();
@@ -1003,7 +944,8 @@ static LogicalResult generateCleanupLoopForUnroll(AffineForOp forOp,
1003944

1004945
cleanupForOp.setLowerBound(cleanupOperands, cleanupMap);
1005946
// Promote the loop body up if this has turned into a single iteration loop.
1006-
(void)promoteIfSingleIteration(cleanupForOp);
947+
IRRewriter rewriter(cleanupForOp.getContext());
948+
(void)cleanupForOp.promoteIfSingleIteration(rewriter);
1007949

1008950
// Adjust upper bound of the original loop; this is the same as the lower
1009951
// bound of the cleanup loop.
@@ -1019,10 +961,11 @@ LogicalResult mlir::affine::loopUnrollByFactor(
1019961
bool cleanUpUnroll) {
1020962
assert(unrollFactor > 0 && "unroll factor should be positive");
1021963

964+
IRRewriter rewriter(forOp.getContext());
1022965
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
1023966
if (unrollFactor == 1) {
1024967
if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
1025-
failed(promoteIfSingleIteration(forOp)))
968+
failed(forOp.promoteIfSingleIteration(rewriter)))
1026969
return failure();
1027970
return success();
1028971
}
@@ -1076,7 +1019,7 @@ LogicalResult mlir::affine::loopUnrollByFactor(
10761019
/*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues);
10771020

10781021
// Promote the loop body up if this has turned into a single iteration loop.
1079-
(void)promoteIfSingleIteration(forOp);
1022+
(void)forOp.promoteIfSingleIteration(rewriter);
10801023
return success();
10811024
}
10821025

@@ -1135,10 +1078,11 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
11351078
uint64_t unrollJamFactor) {
11361079
assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
11371080

1081+
IRRewriter rewriter(forOp.getContext());
11381082
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
11391083
if (unrollJamFactor == 1) {
11401084
if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
1141-
failed(promoteIfSingleIteration(forOp)))
1085+
failed(forOp.promoteIfSingleIteration(rewriter)))
11421086
return failure();
11431087
return success();
11441088
}
@@ -1198,7 +1142,6 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
11981142
// `unrollJamFactor` copies of its iterOperands, iter_args and yield
11991143
// operands.
12001144
SmallVector<AffineForOp, 4> newLoopsWithIterArgs;
1201-
IRRewriter rewriter(forOp.getContext());
12021145
for (AffineForOp oldForOp : loopsWithIterArgs) {
12031146
SmallVector<Value> dupIterOperands, dupYieldOperands;
12041147
ValueRange oldIterOperands = oldForOp.getInits();
@@ -1321,7 +1264,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
13211264
}
13221265

13231266
// Promote the loop body up if this has turned into a single iteration loop.
1324-
(void)promoteIfSingleIteration(forOp);
1267+
(void)forOp.promoteIfSingleIteration(rewriter);
13251268
return success();
13261269
}
13271270

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,8 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
552552

553553
LogicalResult mlir::affine::normalizeAffineFor(AffineForOp op,
554554
bool promoteSingleIter) {
555-
if (promoteSingleIter && succeeded(promoteIfSingleIteration(op)))
555+
IRRewriter rewriter(op.getContext());
556+
if (promoteSingleIter && succeeded(op.promoteIfSingleIteration(rewriter)))
556557
return success();
557558

558559
// Check if the forop is already normalized.

mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,14 @@ void TestAffineDataCopy::runOnOperation() {
107107

108108
// Promote any single iteration loops in the copy nests and simplify
109109
// load/stores.
110+
IRRewriter rewriter(&getContext());
110111
SmallVector<Operation *, 4> copyOps;
111112
for (Operation *nest : copyNests) {
112113
// With a post order walk, the erasure of loops does not affect
113114
// continuation of the walk or the collection of load/store ops.
114115
nest->walk([&](Operation *op) {
115116
if (auto forOp = dyn_cast<AffineForOp>(op))
116-
(void)promoteIfSingleIteration(forOp);
117+
(void)forOp.promoteIfSingleIteration(rewriter);
117118
else if (auto loadOp = dyn_cast<AffineLoadOp>(op))
118119
copyOps.push_back(loadOp);
119120
else if (auto storeOp = dyn_cast<AffineStoreOp>(op))

0 commit comments

Comments
 (0)