Skip to content

Commit e02c2ab

Browse files
author
Aviad Cohen
committed
Introduce new Unroll And Jam loop transform for SCF/Affine loops
Unroll-and-jam has long been supported in the affine dialect through a pass. This commit exposes the transformation as a transform op and, in addition, adds partial support for SCF loops.
1 parent 47f8b85 commit e02c2ab

File tree

8 files changed

+633
-58
lines changed

8 files changed

+633
-58
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,41 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
269269
This operation produces a silenceable failure for payload ops that are
270270
neither `scf.for` nor `affine.for`. If all the operations referred to by the `target` operand
271271
unroll properly, the transform succeeds. Otherwise the transform produces a
272-
silencebale failure.
272+
silenceable failure.
273+
274+
Does not return handles as the operation may result in the loop being
275+
removed after a full unrolling.
276+
}];
277+
278+
let arguments = (ins TransformHandleTypeInterface:$target,
279+
ConfinedAttr<I64Attr, [IntPositive]>:$factor);
280+
281+
let assemblyFormat = "$target attr-dict `:` type($target)";
282+
283+
let extraClassDeclaration = [{
284+
::mlir::DiagnosedSilenceableFailure applyToOne(
285+
::mlir::transform::TransformRewriter &rewriter,
286+
::mlir::Operation *target,
287+
::mlir::transform::ApplyToEachResultList &results,
288+
::mlir::transform::TransformState &state);
289+
}];
290+
}
291+
292+
def LoopUnrollAndJamOp : Op<Transform_Dialect, "loop.unroll_and_jam",
293+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
294+
TransformOpInterface, TransformEachOpTrait]> {
295+
let summary = "Unrolls and jam the given loop with the given unroll factor";
296+
let description = [{
297+
Unrolls & jams each loop associated with the given handle to have up to the given
298+
number of loop body copies per iteration. If the unroll factor is larger
299+
than the loop trip count, the latter is used as the unroll factor instead.
300+
301+
#### Return modes
302+
303+
This operation produces a silenceable failure for payload ops that are
304+
neither `scf.for` nor `affine.for`. If all the operations referred to by the `target` operand
305+
unroll properly, the transform succeeds. Otherwise the transform produces a
306+
silenceable failure.
273307

274308
Does not return handles as the operation may result in the loop being
275309
removed after a full unrolling.

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,14 @@ LogicalResult loopUnrollByFactor(
120120
scf::ForOp forOp, uint64_t unrollFactor,
121121
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
122122

123+
/// Unrolls and jams this `scf.for` operation by the specified unroll factor.
124+
/// Returns failure if the loop cannot be unrolled either due to restrictions or
125+
/// due to invalid unroll factors. In case of unroll factor of 1, the function
126+
/// bails out without doing anything (returns success). Currently, only constant
127+
/// trip counts that are divisible by the unroll factor are supported. Currently,
128+
/// `scf.for` operations with results are not supported.
129+
LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor);
130+
123131
/// Transform a loop with a strictly positive step
124132
/// for %i = %lb to %ub step %s
125133
/// into a 0-based loop with step 1

mlir/include/mlir/Interfaces/LoopLikeInterface.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,39 @@ class HasParallelRegion : public TraitBase<ConcreteType, HasParallelRegion> {
4848
};
4949

5050
} // namespace OpTrait
51+
52+
// Gathers all maximal sub-blocks of operations that do not themselves
53+
// include an `OpTy` (an operation could have a descendant `OpTy` though
54+
// in its tree). Ignores the block terminators.
55+
template <typename OpTy>
56+
struct JamBlockGatherer {
57+
// Store iterators to the first and last op of each sub-block found.
58+
SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
59+
60+
// This is a linear time walk.
61+
void walk(Operation *op) {
62+
for (Region &region : op->getRegions())
63+
for (Block &block : region)
64+
walk(block);
65+
}
66+
67+
void walk(Block &block) {
68+
assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
69+
"expected block to have a terminator");
70+
for (Block::iterator it = block.begin(), e = std::prev(block.end());
71+
it != e;) {
72+
Block::iterator subBlockStart = it;
73+
while (it != e && !isa<OpTy>(&*it))
74+
++it;
75+
if (it != subBlockStart)
76+
subBlocks.emplace_back(subBlockStart, std::prev(it));
77+
// Process all `OpTy` ops that appear next.
78+
while (it != e && isa<OpTy>(&*it))
79+
walk(&*it++);
80+
}
81+
}
82+
};
83+
5184
} // namespace mlir
5285

5386
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,6 @@ using namespace affine;
3737
using namespace presburger;
3838
using llvm::SmallMapVector;
3939

40-
namespace {
41-
// This structure is to pass and return sets of loop parameters without
42-
// confusing the order.
43-
struct LoopParams {
44-
Value lowerBound;
45-
Value upperBound;
46-
Value step;
47-
};
48-
} // namespace
49-
5040
/// Computes the cleanup loop lower bound of the loop being unrolled with
5141
/// the specified unroll factor; this bound will also be upper bound of the main
5242
/// part of the unrolled loop. Computes the bound as an AffineMap with its
@@ -1101,34 +1091,6 @@ static bool areInnerBoundsInvariant(AffineForOp forOp) {
11011091
return !walkResult.wasInterrupted();
11021092
}
11031093

1104-
// Gathers all maximal sub-blocks of operations that do not themselves
1105-
// include a for op (a operation could have a descendant for op though
1106-
// in its tree). Ignore the block terminators.
1107-
struct JamBlockGatherer {
1108-
// Store iterators to the first and last op of each sub-block found.
1109-
std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
1110-
1111-
// This is a linear time walk.
1112-
void walk(Operation *op) {
1113-
for (auto &region : op->getRegions())
1114-
for (auto &block : region)
1115-
walk(block);
1116-
}
1117-
1118-
void walk(Block &block) {
1119-
for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
1120-
auto subBlockStart = it;
1121-
while (it != e && !isa<AffineForOp>(&*it))
1122-
++it;
1123-
if (it != subBlockStart)
1124-
subBlocks.emplace_back(subBlockStart, std::prev(it));
1125-
// Process all for ops that appear next.
1126-
while (it != e && isa<AffineForOp>(&*it))
1127-
walk(&*it++);
1128-
}
1129-
}
1130-
};
1131-
11321094
/// Unrolls and jams this loop by the specified factor.
11331095
LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
11341096
uint64_t unrollJamFactor) {
@@ -1158,7 +1120,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
11581120
return failure();
11591121

11601122
// Gather all sub-blocks to jam upon the loop being unrolled.
1161-
JamBlockGatherer jbg;
1123+
JamBlockGatherer<AffineForOp> jbg;
11621124
jbg.walk(forOp);
11631125
auto &subBlocks = jbg.subBlocks;
11641126

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -348,12 +348,36 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
348348
result = loopUnrollByFactor(scfFor, getFactor());
349349
else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
350350
result = loopUnrollByFactor(affineFor, getFactor());
351+
else
352+
return emitSilenceableError()
353+
<< "failed to unroll, incorrect type of payload";
354+
355+
if (failed(result))
356+
return emitSilenceableError() << "failed to unroll";
357+
358+
return DiagnosedSilenceableFailure::success();
359+
}
360+
361+
//===----------------------------------------------------------------------===//
362+
// LoopUnrollAndJamOp
363+
//===----------------------------------------------------------------------===//
364+
365+
DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
366+
transform::TransformRewriter &rewriter, Operation *op,
367+
transform::ApplyToEachResultList &results,
368+
transform::TransformState &state) {
369+
LogicalResult result(failure());
370+
if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
371+
result = loopUnrollJamByFactor(scfFor, getFactor());
372+
else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
373+
result = loopUnrollJamByFactor(affineFor, getFactor());
374+
else
375+
return emitSilenceableError()
376+
<< "failed to unroll and jam, incorrect type of payload";
377+
378+
if (failed(result))
379+
return emitSilenceableError() << "failed to unroll and jam";
351380

352-
if (failed(result)) {
353-
DiagnosedSilenceableFailure diag = emitSilenceableError()
354-
<< "failed to unroll";
355-
return diag;
356-
}
357381
return DiagnosedSilenceableFailure::success();
358382
}
359383

0 commit comments

Comments
 (0)