Skip to content

Commit 0d0a30b

Browse files
AviadCoAlexisPerry
authored and committed
Introduce new Unroll And Jam loop transform for SCF/Affine loops (llvm#94142)
Unroll And Jam has long been supported in the affine dialect through a pass. This commit exposes the pattern as a transform op and, in addition, adds partial support for SCF loops.
1 parent aadcd64 commit 0d0a30b

File tree

8 files changed

+633
-58
lines changed

8 files changed

+633
-58
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,41 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
269269
This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
270270
in the return. If all the operations referred to by the `target` operand
271271
unroll properly, the transform succeeds. Otherwise the transform produces a
272-
silencebale failure.
272+
silenceable failure.
273+
274+
Does not return handles as the operation may result in the loop being
275+
removed after a full unrolling.
276+
}];
277+
278+
let arguments = (ins TransformHandleTypeInterface:$target,
279+
ConfinedAttr<I64Attr, [IntPositive]>:$factor);
280+
281+
let assemblyFormat = "$target attr-dict `:` type($target)";
282+
283+
let extraClassDeclaration = [{
284+
::mlir::DiagnosedSilenceableFailure applyToOne(
285+
::mlir::transform::TransformRewriter &rewriter,
286+
::mlir::Operation *target,
287+
::mlir::transform::ApplyToEachResultList &results,
288+
::mlir::transform::TransformState &state);
289+
}];
290+
}
291+
292+
def LoopUnrollAndJamOp : Op<Transform_Dialect, "loop.unroll_and_jam",
293+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
294+
TransformOpInterface, TransformEachOpTrait]> {
295+
let summary = "Unrolls and jam the given loop with the given unroll factor";
296+
let description = [{
297+
Unrolls & jams each loop associated with the given handle to have up to the given
298+
number of loop body copies per iteration. If the unroll factor is larger
299+
than the loop trip count, the latter is used as the unroll factor instead.
300+
301+
#### Return modes
302+
303+
This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
304+
in the return. If all the operations referred to by the `target` operand
305+
unroll properly, the transform succeeds. Otherwise the transform produces a
306+
silenceable failure.
273307

274308
Does not return handles as the operation may result in the loop being
275309
removed after a full unrolling.

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,14 @@ LogicalResult loopUnrollByFactor(
120120
scf::ForOp forOp, uint64_t unrollFactor,
121121
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
122122

123+
/// Unrolls and jams this `scf.for` operation by the specified unroll factor.
124+
/// Returns failure if the loop cannot be unrolled either due to restrictions or
125+
/// due to invalid unroll factors. In case of unroll factor of 1, the function
126+
/// bails out without doing anything (returns success). Currently, only constant
127+
/// trip counts that are divisible by the unroll factor are supported. Currently,
128+
/// `for` operations with results are not supported.
129+
LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor);
130+
123131
/// Transform a loop with a strictly positive step
124132
/// for %i = %lb to %ub step %s
125133
/// into a 0-based loop with step 1

mlir/include/mlir/Interfaces/LoopLikeInterface.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,39 @@ class HasParallelRegion : public TraitBase<ConcreteType, HasParallelRegion> {
4848
};
4949

5050
} // namespace OpTrait
51+
52+
// Gathers all maximal sub-blocks of operations that do not themselves
53+
// include an `OpTy` (an operation could have a descendant `OpTy` though
54+
// in its tree). Ignores the block terminators.
55+
template <typename OpTy>
56+
struct JamBlockGatherer {
57+
// Store iterators to the first and last op of each sub-block found.
58+
SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
59+
60+
// Visits each block in each region of `op`; this is a linear time walk.
61+
void walk(Operation *op) {
62+
for (Region &region : op->getRegions())
63+
for (Block &block : region)
64+
walk(block);
65+
}
66+
67+
// Records each run of consecutive non-`OpTy` ops in `block` (the last op, asserted to be a terminator, is excluded) and recurses into every `OpTy` op.
void walk(Block &block) {
68+
assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
69+
"expected block to have a terminator");
70+
for (Block::iterator it = block.begin(), e = std::prev(block.end());
71+
it != e;) {
72+
Block::iterator subBlockStart = it;
73+
while (it != e && !isa<OpTy>(&*it))
74+
++it;
75+
if (it != subBlockStart)
76+
subBlocks.emplace_back(subBlockStart, std::prev(it));
77+
// Process all `OpTy` ops that appear next.
78+
while (it != e && isa<OpTy>(&*it))
79+
walk(&*it++);
80+
}
81+
}
82+
};
83+
5184
} // namespace mlir
5285

5386
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,6 @@ using namespace affine;
3737
using namespace presburger;
3838
using llvm::SmallMapVector;
3939

40-
namespace {
41-
// This structure is to pass and return sets of loop parameters without
42-
// confusing the order.
43-
struct LoopParams {
44-
Value lowerBound;
45-
Value upperBound;
46-
Value step;
47-
};
48-
} // namespace
49-
5040
/// Computes the cleanup loop lower bound of the loop being unrolled with
5141
/// the specified unroll factor; this bound will also be upper bound of the main
5242
/// part of the unrolled loop. Computes the bound as an AffineMap with its
@@ -1101,34 +1091,6 @@ static bool areInnerBoundsInvariant(AffineForOp forOp) {
11011091
return !walkResult.wasInterrupted();
11021092
}
11031093

1104-
// Gathers all maximal sub-blocks of operations that do not themselves
1105-
// include a for op (a operation could have a descendant for op though
1106-
// in its tree). Ignore the block terminators.
1107-
struct JamBlockGatherer {
1108-
// Store iterators to the first and last op of each sub-block found.
1109-
std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
1110-
1111-
// This is a linear time walk.
1112-
void walk(Operation *op) {
1113-
for (auto &region : op->getRegions())
1114-
for (auto &block : region)
1115-
walk(block);
1116-
}
1117-
1118-
void walk(Block &block) {
1119-
for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
1120-
auto subBlockStart = it;
1121-
while (it != e && !isa<AffineForOp>(&*it))
1122-
++it;
1123-
if (it != subBlockStart)
1124-
subBlocks.emplace_back(subBlockStart, std::prev(it));
1125-
// Process all for ops that appear next.
1126-
while (it != e && isa<AffineForOp>(&*it))
1127-
walk(&*it++);
1128-
}
1129-
}
1130-
};
1131-
11321094
/// Unrolls and jams this loop by the specified factor.
11331095
LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
11341096
uint64_t unrollJamFactor) {
@@ -1158,7 +1120,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
11581120
return failure();
11591121

11601122
// Gather all sub-blocks to jam upon the loop being unrolled.
1161-
JamBlockGatherer jbg;
1123+
JamBlockGatherer<AffineForOp> jbg;
11621124
jbg.walk(forOp);
11631125
auto &subBlocks = jbg.subBlocks;
11641126

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -348,12 +348,36 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
348348
result = loopUnrollByFactor(scfFor, getFactor());
349349
else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
350350
result = loopUnrollByFactor(affineFor, getFactor());
351+
else
352+
return emitSilenceableError()
353+
<< "failed to unroll, incorrect type of payload";
354+
355+
if (failed(result))
356+
return emitSilenceableError() << "failed to unroll";
357+
358+
return DiagnosedSilenceableFailure::success();
359+
}
360+
361+
//===----------------------------------------------------------------------===//
362+
// LoopUnrollAndJamOp
363+
//===----------------------------------------------------------------------===//
364+
365+
DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
366+
transform::TransformRewriter &rewriter, Operation *op,
367+
transform::ApplyToEachResultList &results,
368+
transform::TransformState &state) {
369+
LogicalResult result(failure());
370+
if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
371+
result = loopUnrollJamByFactor(scfFor, getFactor());
372+
else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
373+
result = loopUnrollJamByFactor(affineFor, getFactor());
374+
else
375+
return emitSilenceableError()
376+
<< "failed to unroll and jam, incorrect type of payload";
377+
378+
if (failed(result))
379+
return emitSilenceableError() << "failed to unroll and jam";
351380

352-
if (failed(result)) {
353-
DiagnosedSilenceableFailure diag = emitSilenceableError()
354-
<< "failed to unroll";
355-
return diag;
356-
}
357381
return DiagnosedSilenceableFailure::success();
358382
}
359383

0 commit comments

Comments
 (0)