Skip to content

Commit f52e7b2

Browse files
author
Aviad Cohen
committed
Introduce new Unroll And Jam loop transform for SCF/Affine loops
Unroll-and-jam has long been supported in the affine dialect through a pass. This commit exposes it as a transform op and, in addition, adds partial support for SCF loops.
1 parent de37c06 commit f52e7b2

File tree

8 files changed

+646
-37
lines changed

8 files changed

+646
-37
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,41 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
243243
This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
244244
in the return. If all the operations referred to by the `target` operand
245245
unroll properly, the transform succeeds. Otherwise the transform produces a
246-
silencebale failure.
246+
silenceable failure.
247+
248+
Does not return handles as the operation may result in the loop being
249+
removed after a full unrolling.
250+
}];
251+
252+
let arguments = (ins TransformHandleTypeInterface:$target,
253+
ConfinedAttr<I64Attr, [IntPositive]>:$factor);
254+
255+
let assemblyFormat = "$target attr-dict `:` type($target)";
256+
257+
let extraClassDeclaration = [{
258+
::mlir::DiagnosedSilenceableFailure applyToOne(
259+
::mlir::transform::TransformRewriter &rewriter,
260+
::mlir::Operation *target,
261+
::mlir::transform::ApplyToEachResultList &results,
262+
::mlir::transform::TransformState &state);
263+
}];
264+
}
265+
266+
def LoopUnrollAndJamOp : Op<Transform_Dialect, "loop.unroll_and_jam",
267+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
268+
TransformOpInterface, TransformEachOpTrait]> {
269+
let summary = "Unrolls and jam the given loop with the given unroll factor";
270+
let description = [{
271+
Unrolls & jams each loop associated with the given handle to have up to the given
272+
number of loop body copies per iteration. If the unroll factor is larger
273+
than the loop trip count, the latter is used as the unroll factor instead.
274+
275+
#### Return modes
276+
277+
This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
278+
in the return. If all the operations referred to by the `target` operand
279+
unroll properly, the transform succeeds. Otherwise the transform produces a
280+
silenceable failure.
247281

248282
Does not return handles as the operation may result in the loop being
249283
removed after a full unrolling.

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,14 @@ LogicalResult loopUnrollByFactor(
120120
scf::ForOp forOp, uint64_t unrollFactor,
121121
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
122122

123+
/// Unrolls and jams this `scf.for` operation by the specified unroll factor.
124+
/// Returns failure if the loop cannot be unrolled either due to restrictions or
125+
/// due to invalid unroll factors. In case of an unroll factor of 1, the function
126+
/// bails out without doing anything (returns success). Currently, only constant
127+
/// trip counts that are divisible by the unroll factor are supported. Currently,
128+
/// `scf.for` operations with results are not supported.
129+
LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor);
130+
123131
/// Tile a nest of standard for loops rooted at `rootForOp` by finding such
124132
/// parametric tile sizes that the outer loops have a fixed number of iterations
125133
/// as defined in `sizes`.

mlir/include/mlir/IR/LoopUtils.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
//===- LoopUtils.h - LoopUtils Support ---------------------*- C++
2+
//-*-=============//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This file contains loop utilities that are shared between dialects, such as
11+
// the `JamBlockGatherer` helper that collects the sub-blocks of operations
12+
// found between nested loops of a given operation type.
13+
//
14+
//===----------------------------------------------------------------------===//
15+
16+
#ifndef MLIR_IR_LOOP_UTILS_H
17+
#define MLIR_IR_LOOP_UTILS_H
18+
19+
#include "mlir/IR/BuiltinOps.h"
20+
21+
namespace mlir {
22+
23+
// Gathers all maximal sub-blocks of operations that do not themselves
24+
// include an `OpTy` (an operation could have a descendant `OpTy` though
25+
// in its tree). Block terminators are ignored.
26+
template <typename OpTy>
27+
struct JamBlockGatherer {
28+
// Iterators to the first and last op (both inclusive) of each sub-block found.
29+
llvm::SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
30+
31+
// Visits every block of every region of `op`. This is a linear time walk.
32+
void walk(Operation *op) {
33+
for (auto &region : op->getRegions())
34+
for (auto &block : region)
35+
walk(block);
36+
}
37+
38+
// Splits `block` (its terminator excluded) into maximal runs of non-`OpTy`
// ops, records each run in `subBlocks`, and recurses into every `OpTy` op
// found between the runs.
void walk(Block &block) {
39+
assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
40+
"expected block to have a terminator");
41+
for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
42+
auto subBlockStart = it;
43+
while (it != e && !isa<OpTy>(&*it))
44+
++it;
45+
if (it != subBlockStart)
46+
subBlocks.emplace_back(subBlockStart, std::prev(it));
47+
// Recurse into all consecutive `OpTy` ops that appear next.
48+
while (it != e && isa<OpTy>(&*it))
49+
walk(&*it++);
50+
}
51+
}
52+
};
53+
54+
} // namespace mlir
55+
56+
#endif // MLIR_IR_LOOP_UTILS_H

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Dialect/SCF/IR/SCF.h"
2424
#include "mlir/IR/IRMapping.h"
2525
#include "mlir/IR/IntegerSet.h"
26+
#include "mlir/IR/LoopUtils.h"
2627
#include "mlir/Support/MathExtras.h"
2728
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2829
#include "llvm/ADT/MapVector.h"
@@ -1102,34 +1103,6 @@ static bool areInnerBoundsInvariant(AffineForOp forOp) {
11021103
return !walkResult.wasInterrupted();
11031104
}
11041105

1105-
// Gathers all maximal sub-blocks of operations that do not themselves
1106-
// include a for op (a operation could have a descendant for op though
1107-
// in its tree). Ignore the block terminators.
1108-
struct JamBlockGatherer {
1109-
// Store iterators to the first and last op of each sub-block found.
1110-
std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
1111-
1112-
// This is a linear time walk.
1113-
void walk(Operation *op) {
1114-
for (auto &region : op->getRegions())
1115-
for (auto &block : region)
1116-
walk(block);
1117-
}
1118-
1119-
void walk(Block &block) {
1120-
for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
1121-
auto subBlockStart = it;
1122-
while (it != e && !isa<AffineForOp>(&*it))
1123-
++it;
1124-
if (it != subBlockStart)
1125-
subBlocks.emplace_back(subBlockStart, std::prev(it));
1126-
// Process all for ops that appear next.
1127-
while (it != e && isa<AffineForOp>(&*it))
1128-
walk(&*it++);
1129-
}
1130-
}
1131-
};
1132-
11331106
/// Unrolls and jams this loop by the specified factor.
11341107
LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
11351108
uint64_t unrollJamFactor) {
@@ -1159,7 +1132,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
11591132
return failure();
11601133

11611134
// Gather all sub-blocks to jam upon the loop being unrolled.
1162-
JamBlockGatherer jbg;
1135+
JamBlockGatherer<AffineForOp> jbg;
11631136
jbg.walk(forOp);
11641137
auto &subBlocks = jbg.subBlocks;
11651138

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,28 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
304304
return DiagnosedSilenceableFailure::success();
305305
}
306306

307+
//===----------------------------------------------------------------------===//
308+
// LoopUnrollAndJamOp
309+
//===----------------------------------------------------------------------===//
310+
311+
/// Applies unroll-and-jam to a single payload operation `op` using the op's
/// `factor` attribute. `scf.for` and `affine.for` targets are dispatched to
/// their respective `loopUnrollJamByFactor` overloads. Returns success when the
/// utility succeeds and a silenceable error otherwise. `rewriter`, `results`
/// and `state` are not used by this implementation.
DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
312+
transform::TransformRewriter &rewriter, Operation *op,
313+
transform::ApplyToEachResultList &results,
314+
transform::TransformState &state) {
315+
// Starts out as failure, so a target that is neither `scf.for` nor
// `affine.for` also ends up reporting the silenceable error below.
LogicalResult result(failure());
316+
if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
317+
result = loopUnrollJamByFactor(scfFor, getFactor());
318+
else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
319+
result = loopUnrollJamByFactor(affineFor, getFactor());
320+
321+
if (failed(result)) {
322+
DiagnosedSilenceableFailure diag = emitSilenceableError()
323+
<< "failed to unroll and jam";
324+
return diag;
325+
}
326+
return DiagnosedSilenceableFailure::success();
327+
}
328+
307329
//===----------------------------------------------------------------------===//
308330
// LoopCoalesceOp
309331
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)