Skip to content

Commit 2a99e70

Browse files
author
MaheshRavishankar
committed
[mlir][Linalg] NFC: Add utility function to tile, fuse and set marker to use loop.parallel.
This change is NFC since the facility to tile and generate loop.parallel loops already exists in Linalg. Differential Revision: https://reviews.llvm.org/D77965
1 parent 3b2f26a commit 2a99e70

File tree

2 files changed

+48
-7
lines changed

2 files changed

+48
-7
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,18 @@ LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op,
6363
ArrayRef<int64_t> sizes,
6464
StringRef linalgMarker,
6565
ArrayRef<unsigned> permutation);
66+
LogicalResult tileLinalgOpToParallelLoopsAndSetMarker(
67+
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
68+
StringRef linalgMarker, ArrayRef<unsigned> permutation);
6669

6770
/// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and
6871
/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`.
6972
LogicalResult tileAndFuseLinalgOpAndSetMarker(
7073
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
7174
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
75+
LogicalResult tileAndFuseLinalgOpToParallelLoopsAndSetMarker(
76+
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
77+
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
7278

7379
using LinalgLoops = SmallVector<Operation *, 4>;
7480

mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,43 @@ using llvm::SetVector;
4040
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
4141
"__internal_linalg_transform__";
4242

43-
LogicalResult mlir::linalg::tileLinalgOpAndSetMarker(
44-
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
45-
StringRef linalgMarker, ArrayRef<unsigned> permutation) {
43+
using TileFn = Optional<TiledLinalgOp>(OpBuilder &, LinalgOp, ArrayRef<int64_t>,
44+
ArrayRef<unsigned>, OperationFolder *);
45+
46+
static LogicalResult
47+
tileLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter,
48+
Operation *op, ArrayRef<int64_t> sizes,
49+
StringRef linalgMarker,
50+
ArrayRef<unsigned> permutation) {
4651
assert(permutation.empty() || permutation.size() == sizes.size());
47-
auto tileRes = tileLinalgOperation(rewriter, op, sizes, permutation);
52+
auto tileRes = tileFn(rewriter, op, sizes, permutation, /*folder=*/nullptr);
4853
if (!tileRes)
4954
return failure();
5055
tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker,
5156
rewriter.getStringAttr(linalgMarker));
5257
return success();
5358
}
5459

55-
LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker(
60+
LogicalResult mlir::linalg::tileLinalgOpAndSetMarker(
5661
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
57-
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker) {
58-
auto tileRes = tileLinalgOperation(rewriter, op, sizes);
62+
StringRef linalgMarker, ArrayRef<unsigned> permutation) {
63+
return tileLinalgOpAndSetMarkerImpl(tileLinalgOp, rewriter, op, sizes,
64+
linalgMarker, permutation);
65+
}
66+
LogicalResult mlir::linalg::tileLinalgOpToParallelLoopsAndSetMarker(
67+
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
68+
StringRef linalgMarker, ArrayRef<unsigned> permutation) {
69+
return tileLinalgOpAndSetMarkerImpl(tileLinalgOpToParallelLoops, rewriter, op,
70+
sizes, linalgMarker, permutation);
71+
}
72+
73+
static LogicalResult
74+
tileAndFuseLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter,
75+
Operation *op, ArrayRef<int64_t> sizes,
76+
ArrayRef<int64_t> operandIndicesToFuse,
77+
StringRef linalgMarker) {
78+
auto tileRes =
79+
tileFn(rewriter, op, sizes, /*permutation=*/{}, /*folder=*/nullptr);
5980
if (!tileRes)
6081
return failure();
6182
tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker,
@@ -89,6 +110,20 @@ LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker(
89110
return success();
90111
}
91112

113+
LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker(
114+
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
115+
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker) {
116+
return tileAndFuseLinalgOpAndSetMarkerImpl(
117+
tileLinalgOp, rewriter, op, sizes, operandIndicesToFuse, linalgMarker);
118+
}
119+
LogicalResult mlir::linalg::tileAndFuseLinalgOpToParallelLoopsAndSetMarker(
120+
PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
121+
ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker) {
122+
return tileAndFuseLinalgOpAndSetMarkerImpl(
123+
tileLinalgOpToParallelLoops, rewriter, op, sizes, operandIndicesToFuse,
124+
linalgMarker);
125+
}
126+
92127
bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
93128
Operation *consumerOp, Value consumedView,
94129
function_ref<bool(Operation *)> isaOpType) {

0 commit comments

Comments
 (0)