1
- // ===- PadOpInterchange.cpp - Interchange pad operation with Generic ops - -===//
1
+ // ===- PadOpInterchange.cpp - Interchange tensor. pad with linalg producer -===//
2
2
//
3
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
- // This file implements patterns that intechanges a generic op -> pad_tensor
10
- // pattern into extract_slice -> generic_op.
9
+ // This file implements patterns that intechanges a linalg.generic -> tensor.pad
10
+ // op chain into a tensor.extract_slice -> linalg.generic -> tensor.insert_slice
11
+ // op chain.
11
12
//
12
13
// ===----------------------------------------------------------------------===//
13
14
17
18
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
18
19
19
20
using namespace mlir ;
20
- using namespace mlir ::linalg;
21
21
22
22
namespace {
23
23
24
24
// / A sequence of operations
25
25
// /
26
26
// / ```mlir
27
27
// / %0 = linalg. ...
28
- // / %1 = linalg.pad_tensor %0 ...
28
+ // / %1 = tensor.pad %0 ...
29
29
// / ```
30
30
// /
31
31
// / can be replaced with
@@ -40,6 +40,7 @@ namespace {
40
40
// / if the `linalg.generic` has all parallel iterator types.
41
41
struct FusePadOp : OpRewritePattern<tensor::PadOp> {
42
42
using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
43
+
43
44
LogicalResult matchAndRewrite (tensor::PadOp padOp,
44
45
PatternRewriter &rewriter) const override {
45
46
// Only works on padding op that sets the padded value to a constant.
@@ -50,7 +51,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
50
51
// This pattern could work for any Linalg op. For now restrict it to generic
51
52
// ops.
52
53
Value source = padOp.source ();
53
- auto linalgOp = source.getDefiningOp <GenericOp>();
54
+ auto linalgOp = source.getDefiningOp <linalg:: GenericOp>();
54
55
if (!linalgOp) {
55
56
return rewriter.notifyMatchFailure (
56
57
padOp, " expected source to be linalg.generic op" );
@@ -75,14 +76,14 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
75
76
// Create the tensor of same size as output of the pad op.
76
77
RankedTensorType padResultType = padOp.getResultType ();
77
78
auto resultSizes = getAsOpFoldResult (resultShape[0 ]);
78
- auto initTensor = rewriter.create <InitTensorOp>(
79
+ auto initTensor = rewriter.create <linalg:: InitTensorOp>(
79
80
loc, resultSizes, padResultType.getElementType ());
80
81
81
82
// Fill the tensor with the pad value.
82
83
// TODO: There is an option to fill only the boundaries. For now just
83
84
// filling the whole tensor.
84
85
auto fillTensor =
85
- rewriter.create <FillOp>(loc, padValue, initTensor.getResult ());
86
+ rewriter.create <linalg:: FillOp>(loc, padValue, initTensor.getResult ());
86
87
87
88
// Construct a slice of the fill result that is to be replaced with the
88
89
// result of the generic op. The low pad values are the offsets, the size of
@@ -107,7 +108,8 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
107
108
loc, fillTensor.getResult (0 ), offsets, sizes, strides);
108
109
109
110
// Clone the generic op.
110
- auto clonedOp = cast<GenericOp>(rewriter.clone (*linalgOp.getOperation ()));
111
+ auto clonedOp =
112
+ cast<linalg::GenericOp>(rewriter.clone (*linalgOp.getOperation ()));
111
113
clonedOp.setOutputOperand (resultNumber, slice.getResult ());
112
114
113
115
// Insert it back into the result of the fill.
@@ -119,7 +121,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
119
121
};
120
122
} // namespace
121
123
122
- void mlir::linalg::populateFusePadTensorWithProducerLinalgOpPatterns (
124
+ void mlir::linalg::populateFuseTensorPadWithProducerLinalgOpPatterns (
123
125
RewritePatternSet &patterns) {
124
126
patterns.add <FusePadOp>(patterns.getContext ());
125
127
}
0 commit comments