Skip to content

Commit c9efeda

Browse files
committed
fixup! [mlir] Add apply_patterns.linalg.pad_vectorization TD Op
Address comment from Javed
1 parent 17bc7a9 commit c9efeda

File tree

3 files changed

+23
-15
lines changed

3 files changed

+23
-15
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,17 @@ def ApplyPadVectorizationPatternsOp : Op<Transform_Dialect,
8888
"apply_patterns.linalg.pad_vectorization",
8989
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
9090
let description = [{
91-
Apply patterns that take tensor.pad and rewrites it as
92-
vector.transfer_read/vector.transfer_write Ops.
93-
94-
These patterns will either fold tensor.pad with an existing
95-
vector.transfer_read or vector.transfer_write producer/consumers (requires
96-
other surrounding Ops to be already vectorised) or rewrite it, together
97-
with tensor.insert_slice consumer, as a vector.transfer_read +
98-
vector.transfer_write pair.
91+
Apply patterns that vectorize tensor.pad.
92+
93+
These patterns rewrite tensor.pad Ops using vector.transfer_read and
94+
vector.transfer_write operations. This is done either by:
95+
1. Folding tensor.pad with an existing vector.transfer_read /
96+
vector.transfer_write Op (generated prior to running these patterns).
97+
2. Rewriting it (when matched together with q tensor.insert_slice
98+
consumer Op) as a vector.transfer_read + vector.transfer_write pair.
99+
100+
In both cases, these patterns look at producers and consumers for the
101+
matched tensor.pad Op to find opportunities for vectorization.
99102
}];
100103

101104
let assemblyFormat = "attr-dict";

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2304,7 +2304,7 @@ static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
23042304
return result;
23052305
}
23062306

2307-
/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp/GenerateOp and
2307+
/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
23082308
/// InsertSliceOp. For now, only constant padding values are supported.
23092309
/// If there is enough static type information, TransferReadOps and
23102310
/// TransferWriteOps may be generated instead of InsertSliceOps.
@@ -2712,6 +2712,9 @@ struct PadOpVectorizationWithInsertSlicePattern
27122712

27132713
void mlir::linalg::populatePadOpVectorizationPatterns(
27142714
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2715+
// TODO: The following pattern implements "decomposition" and
2716+
// optional "vectorization". Seperate "decomposition" into a sepereate
2717+
// pre-processing pattern group.
27152718
patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
27162719
baseBenefit);
27172720
// Try these specialized patterns first before resorting to the generic one.

mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,18 @@ module attributes {transform.with_named_sequence} {
3737
// -----
3838

3939
///----------------------------------------------------------------------------------------
40-
/// [Pattern: PadOpVectorizationWithTransferReadPattern
40+
/// [Pattern: PadOpVectorizationWithTransferWritePattern]
4141
///----------------------------------------------------------------------------------------
4242
func.func private @make_vector() -> vector<7x9xf32>
4343

44-
// CHECK-LABEL: func @pad_and_transfer_write_static
44+
// CHECK-LABEL: func @pad_and_transfer_write_static_low_and_high
4545
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
4646
// CHECK-NOT: tensor.pad
4747
// CHECK: %[[C0:.*]] = arith.constant 0 : index
4848
// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
4949
// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32>
5050
// CHECK: return %[[RESULT]]
51-
func.func @pad_and_transfer_write_static(
51+
func.func @pad_and_transfer_write_static_low_and_high(
5252
%arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
5353
%c0 = arith.constant 0 : index
5454
%c5 = arith.constant 5.0 : f32
@@ -78,15 +78,15 @@ module attributes {transform.with_named_sequence} {
7878

7979
func.func private @make_vector() -> vector<7x9xf32>
8080

81-
// CHECK-LABEL: func @pad_and_transfer_write_dynamic_static
81+
// CHECK-LABEL: func @pad_and_transfer_write_static_low_dynamic_high
8282
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[SIZE:.*]]: index, %[[PADDING:.*]]: index
8383
// CHECK-NOT: tensor.pad
8484
// CHECK: %[[C0:.*]] = arith.constant 0 : index
8585
// CHECK: %[[SUB:.*]] = tensor.extract_slice %[[ARG0]][0, 0] [%[[SIZE]], 6] [1, 1] : tensor<?x?xf32> to tensor<?x6xf32>
8686
// CHECK: %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
8787
// CHECK: %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
8888
// CHECK: return %[[RESULT]]
89-
func.func @pad_and_transfer_write_dynamic_static(
89+
func.func @pad_and_transfer_write_static_low_dynamic_high(
9090
%arg0: tensor<?x?xf32>, %size: index, %padding: index) -> tensor<?x6xf32> {
9191
%c0 = arith.constant 0 : index
9292
%c5 = arith.constant 5.0 : f32
@@ -166,7 +166,9 @@ module attributes {transform.with_named_sequence} {
166166

167167
func.func private @make_vector() -> tensor<12x13xf32>
168168

169-
// Same as @pad_and_insert_slice_dest in vectorization-wit-patterns.mlir, but
169+
// Same as @pad_and_insert_slice_dest in vectorization-with-patterns.mlir, but
170+
// over here linalg::fill is not vectorized (patterns for linalg.fill are not
171+
// included here)
170172
// CHECK-LABEL: func.func @pad_and_insert_slice_dest(
171173
// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x5x6xf32>) -> tensor<1x12x13xf32> {
172174
// CHECK-NOT: tensor.pad

0 commit comments

Comments
 (0)