Skip to content

Commit 96eb4d3

Browse files
committed
[mlir][linalg] Extract GeneralizePadOpPattern into a standalone transformation
Currently, `GeneralizePadOpPattern` is grouped under `populatePadOpVectorizationPatterns`. However, as noted in #111349, this transformation "decomposes" rather than "vectorizes" `tensor.pad`. As such, it functions as: * a vectorization _pre-processing_ transformation, not * a vectorization transformation itself. To clarify its purpose, this PR turns `GeneralizePadOpPattern` into a standalone transformation by: * introducing a dedicated `populateDecomposePadPatterns` method, * adding a `apply_patterns.linalg.decompose_pad` Transform Dialect Op, and * removing it from `populatePadOpVectorizationPatterns`. In addition, to better reflect its role, it is renamed as "decomposition" rather then "generalization". That's to better reflect its role. This is in line with the recent renaming of similar ops, i.e. tensor.pack/tensor.unpack Ops in #116439.
1 parent 157d847 commit 96eb4d3

File tree

9 files changed

+50
-20
lines changed

9 files changed

+50
-20
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ def ApplyDecomposeTensorPackUnpackPatternsOp
5252
let assemblyFormat = "attr-dict";
5353
}
5454

55+
def ApplyDecomposeTensorPadPatternsOp
56+
: Op<Transform_Dialect, "apply_patterns.linalg.decompose_pad",
57+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
58+
let description = [{
59+
Collect patterns to decompose tensor.pad into e.g. tensor::EmptyOp,
60+
linalg::FillOp and tensor::InsertSliceOp.
61+
}];
62+
63+
let assemblyFormat = "attr-dict";
64+
}
65+
5566
def ApplyFoldUnitExtentDimsViaReshapesPatternsOp : Op<Transform_Dialect,
5667
"apply_patterns.linalg.fold_unit_extent_dims_via_reshapes",
5768
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1503,8 +1503,8 @@ using OptimizeCopyFn =
15031503

15041504
/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
15051505
/// InsertSliceOp. For now, only constant padding values are supported.
1506-
struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
1507-
GeneralizePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
1506+
struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
1507+
DecomposePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
15081508
: OpRewritePattern<tensor::PadOp>(context, benefit) {}
15091509
LogicalResult matchAndRewrite(tensor::PadOp padOp,
15101510
PatternRewriter &rewriter) const override;
@@ -1688,6 +1688,10 @@ void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
16881688
/// outer dims to be unit.
16891689
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns);
16901690

1691+
/// Populates patterns to decompose tensor.pad into e.g.
1692+
/// tensor.empty, linalg.fill, tensor.insert_slice.
1693+
void populateDecomposePadPatterns(RewritePatternSet &patterns);
1694+
16911695
/// Populates patterns to transform linalg.conv_2d_xxx operations into
16921696
/// linalg.generic (for img2col packing) and linalg.matmul.
16931697
/// \see rewriteInIm2Col for more details.

mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,7 @@ using namespace mlir;
2525
//===----------------------------------------------------------------------===//
2626

2727
void mlir::populateTensorToLinalgPatterns(RewritePatternSet &patterns) {
28-
patterns.add<mlir::linalg::GeneralizePadOpPattern>(patterns.getContext());
28+
// TODO: Add the remaining patterns, e.g. to decompose Pack/Unpack Ops.
29+
// Alternatively, delete this file.
30+
patterns.add<mlir::linalg::DecomposePadOpPattern>(patterns.getContext());
2931
}

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,11 @@ void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
234234
linalg::populateDecomposePackUnpackPatterns(patterns);
235235
}
236236

237+
void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
238+
RewritePatternSet &patterns) {
239+
linalg::populateDecomposePadPatterns(patterns);
240+
}
241+
237242
void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
238243
RewritePatternSet &patterns) {
239244
linalg::ControlDropUnitDims options;
@@ -3491,8 +3496,12 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
34913496
// Add misc. vectorization patterns (e.g. for tensor.insert_slice)
34923497
linalg::populateInsertSliceVectorizationPatterns(patterns);
34933498

3494-
if (getVectorizePadding())
3499+
if (getVectorizePadding()) {
34953500
linalg::populatePadOpVectorizationPatterns(patterns);
3501+
// This creates an alternative path for lowering tensor.pad - by
3502+
// decomposing it into e.g. linalg.fill.
3503+
linalg::populateDecomposePadPatterns(patterns);
3504+
}
34963505
vector::populateVectorStepLoweringPatterns(patterns);
34973506

34983507
TrackingListener listener(state, *this);

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -921,7 +921,7 @@ LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
921921

922922
/// Filling `dest` using FillOp constant padding value if possible.
923923
/// Otherwise, generate a tensor::GenerateOp.
924-
Value GeneralizePadOpPattern::createFillOrGenerateOp(
924+
Value DecomposePadOpPattern::createFillOrGenerateOp(
925925
RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
926926
const SmallVector<Value> &dynSizes) const {
927927
auto padValue = padOp.getConstantPaddingValue();
@@ -938,8 +938,8 @@ Value GeneralizePadOpPattern::createFillOrGenerateOp(
938938
}
939939

940940
LogicalResult
941-
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
942-
PatternRewriter &rewriter) const {
941+
DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
942+
PatternRewriter &rewriter) const {
943943
// Given an OpFoldResult, return an index-typed value.
944944
auto getIdxValue = [&](OpFoldResult ofr) {
945945
if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
@@ -1623,3 +1623,7 @@ void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
16231623
// TODO: Add and test patterns for tensor.unpack
16241624
patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
16251625
}
1626+
1627+
void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
1628+
patterns.add<DecomposePadOpPattern>(patterns.getContext());
1629+
}

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2770,12 +2770,6 @@ void mlir::linalg::populateInsertSliceVectorizationPatterns(
27702770

27712771
void mlir::linalg::populatePadOpVectorizationPatterns(
27722772
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2773-
// TODO: The following pattern implements "decomposition" and
2774-
// optional "vectorization". Seperate "decomposition" into a sepereate
2775-
// pre-processing pattern group.
2776-
patterns.add<GeneralizePadOpPattern>(patterns.getContext(), baseBenefit);
2777-
2778-
// Try these specialized patterns first before resorting to the generic one.
27792773
patterns.add<PadOpVectorizationWithTransferReadPattern,
27802774
PadOpVectorizationWithTransferWritePattern,
27812775
PadOpVectorizationWithInsertSlicePattern>(

mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir renamed to mlir/test/Dialect/Linalg/decompose-pad-tensor.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-pad-tensor" %s | FileCheck %s
1+
// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-decompose-pad-tensor" %s | FileCheck %s
22

33
// CHECK-LABEL: func @generalize_pad_tensor_static_shape(
44
// CHECK-SAME: %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {

mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ module attributes {transform.with_named_sequence} {
202202
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
203203

204204
transform.apply_patterns to %func_op {
205+
// TODO: Split into two tests, one for each pattern
206+
transform.apply_patterns.linalg.decompose_pad
205207
transform.apply_patterns.linalg.pad_vectorization
206208
} : !transform.op<"func.func">
207209
transform.yield
@@ -236,6 +238,8 @@ module attributes {transform.with_named_sequence} {
236238
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
237239

238240
transform.apply_patterns to %func_op {
241+
// TODO: Split into two tests, one for each pattern
242+
transform.apply_patterns.linalg.decompose_pad
239243
transform.apply_patterns.linalg.pad_vectorization
240244
} : !transform.op<"func.func">
241245
transform.yield
@@ -270,6 +274,8 @@ module attributes {transform.with_named_sequence} {
270274
%func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
271275

272276
transform.apply_patterns to %func_op {
277+
// TODO: Split into two tests, one for each pattern
278+
transform.apply_patterns.linalg.decompose_pad
273279
transform.apply_patterns.linalg.pad_vectorization
274280
} : !transform.op<"func.func">
275281
transform.yield

mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ struct TestLinalgTransforms
7070
llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
7171
"in vector.contract form"),
7272
llvm::cl::init(false)};
73-
Option<bool> testGeneralizePadTensor{
74-
*this, "test-generalize-pad-tensor",
73+
Option<bool> testDecomposePadTensor{
74+
*this, "test-decompose-pad-tensor",
7575
llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
7676
llvm::cl::init(false)};
7777
Option<bool> testDecomposeTensorPackOp{
@@ -166,9 +166,9 @@ static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
166166
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
167167
}
168168

169-
static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
169+
static void applyDecomposePadPatterns(func::FuncOp funcOp) {
170170
RewritePatternSet patterns(funcOp.getContext());
171-
patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
171+
patterns.add<DecomposePadOpPattern>(funcOp.getContext());
172172
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
173173
}
174174

@@ -235,8 +235,8 @@ void TestLinalgTransforms::runOnOperation() {
235235
return applyVectorTransferForwardingPatterns(getOperation());
236236
if (testGenericToVectorPattern)
237237
return applyLinalgToVectorPatterns(getOperation());
238-
if (testGeneralizePadTensor)
239-
return applyGeneralizePadTensorPatterns(getOperation());
238+
if (testDecomposePadTensor)
239+
return applyDecomposePadPatterns(getOperation());
240240
if (testDecomposeTensorPackOp)
241241
return applyDecomposeTensorPackPatterns(getOperation());
242242
if (testDecomposeTensorUnPackOp)

0 commit comments

Comments
 (0)