Skip to content

Commit 7360ed0

Browse files
[mlir][transform] Drop redundant padding_dimensions spec from pad_tiling_interface (#145257)
This revision aligns padding specification in pad_tiling_interface to that of tiling specification. Dimensions that should be skipped are specified by "padding by 0". Trailing dimensions that are ignored are automatically completed to "pad to 0".
1 parent d9a99af commit 7360ed0

File tree

5 files changed

+58
-81
lines changed

5 files changed

+58
-81
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,17 +1195,29 @@ def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interfac
11951195
TransformOpInterface,
11961196
ReportTrackingListenerFailuresOpTrait]> {
11971197
let description = [{
1198-
Pads the operations pointed to by the target handle using the options
1199-
provided as operation attributes. The operation returns a handle to the
1200-
padded operation and to the padding operation ("tensor.pad").
1198+
Pads the **iteration domain** of the operations pointed to by the target
1199+
handle using the options provided as operation attributes. Padding the
1200+
iteration domain induces a padding of the operands that is consistent
1201+
across the op semantics and, unlike for simple elementwise ops, may not be
1202+
trivially deducible or specifiable on operands only (e.g. convolutions).
1203+
1204+
The specification of `padding_sizes` follows that of `tile_sizes` during
1205+
tiling: the value "0" on a particular iterator encodes "no padding". Like in
1206+
the case of tiling, an automatic completion by 0 to the operation rank
1207+
occurs.
1208+
1209+
This transformation returns a handle to the padded operation and to the
1210+
padding operation ("tensor.pad").
12011211

12021212
TODO: in the future this should be moved out of a specific Linalg
12031213
implementation file and into a more general "Structured" file.
12041214

12051215
#### Return modes
12061216

1207-
This operation ignores non-Linalg ops and drops them in the return.
1208-
In the future, this operation will support all TilingInterfaceOps.
1217+
This operation ignores non-IndexingMapOpInterface ops and drops them in the
1218+
return. In the future, this operation will support all TilingInterfaceOps
1219+
for which the contract between iteration domain and operands can be
1220+
reified.
12091221

12101222
This operation may produce a definite failure if the padding fails for any
12111223
reason.
@@ -1219,7 +1231,6 @@ def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interfac
12191231
let arguments =
12201232
(ins TransformHandleTypeInterface:$target,
12211233
DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
1222-
DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
12231234
Variadic<TransformAnyParamTypeOrAnyHandle>:$padding_sizes,
12241235
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
12251236
$static_padding_sizes,
@@ -1245,11 +1256,9 @@ def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interfac
12451256
// add/mul ring at the moment.
12461257
// TODO: support other operations (e.g. min, max etc).
12471258
OpBuilder<(ins "Value":$target,
1248-
"ArrayRef<int64_t>":$paddingDimensions,
12491259
CArg<"ArrayRef<int64_t>", "{}">:$staticPaddingSizes,
12501260
CArg<"bool", "false">:$padToMultipleOf)>,
12511261
OpBuilder<(ins "Value":$target,
1252-
"ArrayRef<int64_t>":$paddingDimensions,
12531262
"ArrayRef<OpFoldResult>":$mixedPadPaddingSizes,
12541263
CArg<"bool", "false">:$usePrescribedTensorShapes)>
12551264
];

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2163,7 +2163,6 @@ LogicalResult transform::PadOp::verify() {
21632163
void transform::PadTilingInterfaceOp::build(OpBuilder &b,
21642164
OperationState &result,
21652165
Value target,
2166-
ArrayRef<int64_t> paddingDimensions,
21672166
ArrayRef<int64_t> paddingSizes,
21682167
bool padToMultipleOf) {
21692168
auto resultType = transform::AnyOpType::get(b.getContext());
@@ -2172,7 +2171,6 @@ void transform::PadTilingInterfaceOp::build(OpBuilder &b,
21722171
/*types=*/TypeRange{resultType, resultType},
21732172
/*target=*/target,
21742173
/*paddingValues=*/ArrayAttr(), // let inference handle this
2175-
/*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
21762174
/*paddingSizes=*/ValueRange{},
21772175
/*paddingSizes=*/
21782176
(paddingSizes.empty() ? DenseI64ArrayAttr()
@@ -2183,7 +2181,6 @@ void transform::PadTilingInterfaceOp::build(OpBuilder &b,
21832181

21842182
void transform::PadTilingInterfaceOp::build(
21852183
OpBuilder &b, OperationState &result, Value target,
2186-
ArrayRef<int64_t> paddingDimensions,
21872184
ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
21882185
auto resultType = transform::AnyOpType::get(b.getContext());
21892186
SmallVector<int64_t> staticPaddingSizes;
@@ -2195,7 +2192,6 @@ void transform::PadTilingInterfaceOp::build(
21952192
/*types=*/TypeRange{resultType, resultType},
21962193
/*target=*/target,
21972194
/*paddingValues=*/ArrayAttr(), // let inference handle this
2198-
/*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
21992195
/*paddingSizes=*/dynamicPaddingSizes,
22002196
/*paddingSizes=*/staticPaddingSizes,
22012197
/*usePrescribedTensorShapes=*/padToMultipleOf);
@@ -2277,8 +2273,6 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
22772273
TilingInterface paddedOp;
22782274
PadTilingInterfaceOptions options;
22792275
options.setPaddingValues(paddingValues)
2280-
.setPaddingDimensions(
2281-
extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions()))
22822276
.setPaddingSizes(getMixedPaddingSizes())
22832277
.setPadToMultipleOf(getPadToMultipleOf());
22842278

@@ -2303,20 +2297,7 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
23032297
return DiagnosedSilenceableFailure::success();
23042298
}
23052299

2306-
LogicalResult transform::PadTilingInterfaceOp::verify() {
2307-
SmallVector<int64_t> paddingDimensions =
2308-
extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2309-
if (any_of(paddingDimensions,
2310-
[](int64_t paddingDimension) { return paddingDimension < 0; })) {
2311-
return emitOpError() << "expects padding_dimensions to contain positive "
2312-
"integers, found "
2313-
<< getPaddingDimensions();
2314-
}
2315-
if (getMixedPaddingSizes().size() != paddingDimensions.size()) {
2316-
return emitOpError() << "expects as many multiples as padding_dimensions";
2317-
}
2318-
return success();
2319-
}
2300+
LogicalResult transform::PadTilingInterfaceOp::verify() { return success(); }
23202301

23212302
//===---------------------------------------------------------------------===//
23222303
// HoistPadOp

mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,29 +32,27 @@ using namespace mlir::tensor;
3232
#define DBGSNL() (llvm::dbgs() << "\n")
3333

3434
/// Form a "full-rank" padding specification so that the application is easy.
35-
static llvm::SmallDenseMap<int64_t, OpFoldResult>
36-
getDimsToSize(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
37-
const PadTilingInterfaceOptions &options) {
38-
llvm::SmallDenseMap<int64_t, OpFoldResult> dimsToSize;
39-
for (const auto &[paddingDim, paddingSize] :
40-
llvm::zip_equal(options.paddingDimensions, options.paddingSizes)) {
41-
dimsToSize[paddingDim] = paddingSize;
42-
}
35+
static SmallVector<OpFoldResult>
36+
getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
37+
const PadTilingInterfaceOptions &options) {
38+
SmallVector<OpFoldResult> paddingSizes;
4339
// Complete the padding specification to specify all dimensions.
44-
for (int64_t idx = 0, e = indexingSizes.size(); idx != e; ++idx) {
45-
if (dimsToSize.find(idx) != dimsToSize.end())
46-
continue;
47-
// If a dimension is not specified, either complete with:
40+
for (size_t idx = 0, e = indexingSizes.size(); idx != e; ++idx) {
41+
// Complete to zero if needed.
42+
paddingSizes.push_back(options.paddingSizes.size() > idx
43+
? options.paddingSizes[idx]
44+
: b.getIndexAttr(0));
45+
// If a dimension is zero (either specified or completed), replace by:
4846
// - 1 if we are padding to the next multiple of.
4947
// - indexingSizes[idx] otherwise
50-
dimsToSize[idx] =
51-
options.padToMultipleOf ? b.getIndexAttr(1) : indexingSizes[idx];
52-
}
53-
for (int64_t idx = 0, e = indexingSizes.size(); idx != e; ++idx) {
54-
LLVM_DEBUG(DBGS() << "----idx: " << idx << " : " << dimsToSize[idx]
48+
if (isZeroInteger(paddingSizes[idx])) {
49+
paddingSizes[idx] =
50+
options.padToMultipleOf ? b.getIndexAttr(1) : indexingSizes[idx];
51+
}
52+
LLVM_DEBUG(DBGS() << "----idx: " << idx << " : " << paddingSizes[idx]
5553
<< "\n");
5654
}
57-
return dimsToSize;
55+
return paddingSizes;
5856
}
5957

6058
/// Compute the padded shape of the given value `v` of `RankedTensorType` given
@@ -80,8 +78,8 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
8078
"rank");
8179

8280
// "Full-rank" padding specification.
83-
llvm::SmallDenseMap<int64_t, OpFoldResult> dimsToSize =
84-
getDimsToSize(rewriter, indexingSizes, options);
81+
SmallVector<OpFoldResult> paddingSizes =
82+
getFullRankPaddingSizes(rewriter, indexingSizes, options);
8583

8684
// For each dimension in the operand's shape, iterate over indexingSizes and
8785
// add the various term contributions.
@@ -97,7 +95,9 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
9795
// Find all padding dimensions that contribute to this operand dimension
9896
// and compute the padded term contribution to the final padded shape.
9997
SmallVector<OpFoldResult> terms;
100-
for (const auto &[paddingDim, paddingSize] : dimsToSize) {
98+
for (size_t paddingDim = 0, e = paddingSizes.size(); paddingDim != e;
99+
++paddingDim) {
100+
OpFoldResult paddingSize = paddingSizes[paddingDim];
101101
LLVM_DEBUG(DBGS() << "------try apply padding of dim: " << paddingDim
102102
<< " to: " << paddingSize << "\n");
103103
if (!enResults.value().isFunctionOfDim(paddingDim))
@@ -224,9 +224,6 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
224224
SmallVector<tensor::PadOp> &padOps,
225225
PadSizeComputationFunction computePaddingSizeFun) {
226226
LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
227-
assert(constOptions.paddingSizes.size() ==
228-
constOptions.paddingDimensions.size() &&
229-
"invalid number of elements in padToMultipleOf");
230227

231228
Location loc = opToPad.getLoc();
232229
PadTilingInterfaceOptions options(constOptions);

mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ module attributes {transform.with_named_sequence} {
3636
// Tile to 5 then pad to 8 (supposedly to better hit vector ops).
3737
%matmul_l1, %loops_l1 = transform.structured.tile_using_for %matmul tile_sizes [5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
3838
%matmul_padded, %_ = transform.structured.pad_tiling_interface %matmul_l1 to padding_sizes [8] pad_to_multiple_of {
39-
padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
40-
padding_dimensions=[0]
39+
padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32]
4140
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
4241

4342
transform.yield
@@ -71,11 +70,10 @@ module {
7170
return %0 : tensor<7x11x12xf32>
7271
}
7372
module attributes {transform.with_named_sequence} {
74-
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
75-
%0 = transform.structured.match ops{["linalg.generic"]} in %module_op : (!transform.any_op) -> !transform.any_op
76-
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [3, 5] pad_to_multiple_of {
77-
padding_dimensions = [0, 2],
78-
padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
73+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
74+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
75+
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [3, 0, 5] pad_to_multiple_of {
76+
padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32]
7977
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
8078
transform.yield
8179
}
@@ -126,11 +124,10 @@ module {
126124
return %0 : tensor<?x11x?xf32>
127125
}
128126
module attributes {transform.with_named_sequence} {
129-
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
130-
%0 = transform.structured.match ops{["linalg.generic"]} in %module_op : (!transform.any_op) -> !transform.any_op
131-
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [3, 5] pad_to_multiple_of {
132-
padding_dimensions = [0, 2],
133-
padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
127+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
128+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
129+
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [3, 0, 5] pad_to_multiple_of {
130+
padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32]
134131
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
135132
transform.yield
136133
}
@@ -172,9 +169,8 @@ module attributes {transform.with_named_sequence} {
172169
: (!transform.any_op) -> !transform.any_op
173170

174171
// Pad then tile should produce static shapes.
175-
%matmul_padded, %_ = transform.structured.pad_tiling_interface %matmul to padding_sizes [8, 16] pad_to_multiple_of {
176-
padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
177-
padding_dimensions=[0, 2]
172+
%matmul_padded, %_ = transform.structured.pad_tiling_interface %matmul to padding_sizes [8, 0, 16] pad_to_multiple_of {
173+
padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32]
178174
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
179175

180176
%m, %l0, %l1 = transform.structured.tile_using_for %matmul_padded tile_sizes [8, 0, 16]
@@ -234,9 +230,8 @@ module attributes {transform.with_named_sequence} {
234230
%m, %l0, %l1 = transform.structured.tile_using_for %matmul tile_sizes [8, 0, 16]
235231
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
236232

237-
%matmul_padded, %_ = transform.structured.pad_tiling_interface %m to padding_sizes [8, 16] pad_to_multiple_of {
238-
padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
239-
padding_dimensions=[0, 2]
233+
%matmul_padded, %_ = transform.structured.pad_tiling_interface %m to padding_sizes [8, 0, 16] pad_to_multiple_of {
234+
padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32]
240235
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
241236

242237
transform.yield
@@ -269,9 +264,8 @@ module attributes {transform.with_named_sequence} {
269264
%m, %l0, %l1 = transform.structured.tile_using_for %matmul tile_sizes [8, 0, 16]
270265
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
271266

272-
%matmul_padded, %_ = transform.structured.pad_tiling_interface %m to padding_sizes [8, 16] pad_to_multiple_of {
273-
padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
274-
padding_dimensions=[0, 2]
267+
%matmul_padded, %_ = transform.structured.pad_tiling_interface %m to padding_sizes [8, 0, 16] pad_to_multiple_of {
268+
padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32]
275269
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
276270

277271
transform.yield

mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ module attributes {transform.with_named_sequence} {
1818
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
1919

2020
%fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
21-
padding_values=[0.0 : f32, 0.0 : f32],
22-
padding_dimensions=[0]
21+
padding_values=[0.0 : f32, 0.0 : f32]
2322
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
2423

2524
transform.yield
@@ -55,8 +54,7 @@ module attributes {transform.with_named_sequence} {
5554
// Tile to 5 then pad to 8 (supposedly to better hit vector ops).
5655
%matmul_l1, %loops_l1 = transform.structured.tile_using_for %matmul tile_sizes [5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
5756
%matmul_padded, %_ = transform.structured.pad_tiling_interface %matmul_l1 to padding_sizes [8] {
58-
padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
59-
padding_dimensions=[0]
57+
padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32]
6058
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
6159

6260
transform.yield
@@ -91,8 +89,7 @@ module {
9189
module attributes {transform.with_named_sequence} {
9290
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
9391
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
94-
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 14] {
95-
padding_dimensions = [0, 2],
92+
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
9693
padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
9794
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
9895
transform.yield
@@ -147,8 +144,7 @@ module {
147144
module attributes {transform.with_named_sequence} {
148145
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
149146
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
150-
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 14] {
151-
padding_dimensions = [0, 2],
147+
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
152148
padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
153149
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
154150
transform.yield

0 commit comments

Comments
 (0)