Skip to content

Commit 7d249df

Browse files
committed
[mlir][linalg] NFC: minor cleanups after moving pad to tensor dialect
Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D120627
1 parent 5aeaabf commit 7d249df

File tree

3 files changed

+16
-15
lines changed

3 files changed

+16
-15
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,9 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
103103
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
104104
RewritePatternSet &patterns);
105105

106-
/// Pattern to fuse a `linalg.pad_tensor` operation with the producer of its
107-
/// source, if the producer is a `linalg` operation with all parallel iterator
108-
/// types.
109-
void populateFusePadTensorWithProducerLinalgOpPatterns(
106+
/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
107+
/// if the producer is a `linalg` operation with all parallel iterator types.
108+
void populateFuseTensorPadWithProducerLinalgOpPatterns(
110109
RewritePatternSet &patterns);
111110

112111
/// Patterns to convert from one named op to another. These can be seen as

mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
//===- PadOpInterchange.cpp - Interchange pad operation with Generic ops --===//
1+
//===- PadOpInterchange.cpp - Interchange tensor.pad with linalg producer -===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file implements patterns that intechanges a generic op -> pad_tensor
10-
// pattern into extract_slice -> generic_op.
9+
// This file implements patterns that intechanges a linalg.generic -> tensor.pad
10+
// op chain into a tensor.extract_slice -> linalg.generic -> tensor.insert_slice
11+
// op chain.
1112
//
1213
//===----------------------------------------------------------------------===//
1314

@@ -17,15 +18,14 @@
1718
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1819

1920
using namespace mlir;
20-
using namespace mlir::linalg;
2121

2222
namespace {
2323

2424
/// A sequence of operations
2525
///
2626
/// ```mlir
2727
/// %0 = linalg. ...
28-
/// %1 = linalg.pad_tensor %0 ...
28+
/// %1 = tensor.pad %0 ...
2929
/// ```
3030
///
3131
/// can be replaced with
@@ -40,6 +40,7 @@ namespace {
4040
/// if the `linalg.generic` has all parallel iterator types.
4141
struct FusePadOp : OpRewritePattern<tensor::PadOp> {
4242
using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
43+
4344
LogicalResult matchAndRewrite(tensor::PadOp padOp,
4445
PatternRewriter &rewriter) const override {
4546
// Only works on padding op that sets the padded value to a constant.
@@ -50,7 +51,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
5051
// This pattern could work for any Linalg op. For now restrict it to generic
5152
// ops.
5253
Value source = padOp.source();
53-
auto linalgOp = source.getDefiningOp<GenericOp>();
54+
auto linalgOp = source.getDefiningOp<linalg::GenericOp>();
5455
if (!linalgOp) {
5556
return rewriter.notifyMatchFailure(
5657
padOp, "expected source to be linalg.generic op");
@@ -75,14 +76,14 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
7576
// Create the tensor of same size as output of the pad op.
7677
RankedTensorType padResultType = padOp.getResultType();
7778
auto resultSizes = getAsOpFoldResult(resultShape[0]);
78-
auto initTensor = rewriter.create<InitTensorOp>(
79+
auto initTensor = rewriter.create<linalg::InitTensorOp>(
7980
loc, resultSizes, padResultType.getElementType());
8081

8182
// Fill the tensor with the pad value.
8283
// TODO: There is an option to fill only the boundaries. For now just
8384
// filling the whole tensor.
8485
auto fillTensor =
85-
rewriter.create<FillOp>(loc, padValue, initTensor.getResult());
86+
rewriter.create<linalg::FillOp>(loc, padValue, initTensor.getResult());
8687

8788
// Construct a slice of the fill result that is to be replaced with the
8889
// result of the generic op. The low pad values are the offsets, the size of
@@ -107,7 +108,8 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
107108
loc, fillTensor.getResult(0), offsets, sizes, strides);
108109

109110
// Clone the generic op.
110-
auto clonedOp = cast<GenericOp>(rewriter.clone(*linalgOp.getOperation()));
111+
auto clonedOp =
112+
cast<linalg::GenericOp>(rewriter.clone(*linalgOp.getOperation()));
111113
clonedOp.setOutputOperand(resultNumber, slice.getResult());
112114

113115
// Insert it back into the result of the fill.
@@ -119,7 +121,7 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
119121
};
120122
} // namespace
121123

122-
void mlir::linalg::populateFusePadTensorWithProducerLinalgOpPatterns(
124+
void mlir::linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(
123125
RewritePatternSet &patterns) {
124126
patterns.add<FusePadOp>(patterns.getContext());
125127
}

mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ struct TestPadFusionPass
3434
MLIRContext *context = &getContext();
3535
FuncOp funcOp = getOperation();
3636
RewritePatternSet patterns(context);
37-
linalg::populateFusePadTensorWithProducerLinalgOpPatterns(patterns);
37+
linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(patterns);
3838
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
3939
std::move(patterns))))
4040
return signalPassFailure();

0 commit comments

Comments
 (0)