Commit 97c1a24
[mlir][linalg] Add option to pad dynamic dims to linalg::rewriteAsPaddedOp (#144354)
This patch makes the following changes:

- Add a `ValueRange typeDynDims` argument to `linalg::makeComposedPadHighOp`, making it possible to pad a tensor with dynamic dimensions using `tensor::createPadHighOp`.
- Add a `DenseMap<std::pair<unsigned, unsigned>, OpFoldResult> sizeToPadTo;` option to `LinalgPaddingOptions`. This option sets the size to use when padding a given dimension of an operand, so operands can be padded even when they have no constant upper bounding box. If no value is provided, the constant upper bound is used by default.
- Add a `use_prescribed_tensor_shapes` option to `transform.structured.pad`. If set to true, `tensor.dim` ops are used to compute the size of each padded dimension instead of the constant upper bound.
- Change how `linalg::rewriteAsPaddedOp` computes the padded shape, using the newly added options in `LinalgPaddingOptions`.
- Add tests verifying the new behavior.
1 parent c0a9c90 commit 97c1a24
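As a concrete illustration of the new `LinalgPaddingOptions` knob, here is a minimal C++ sketch. It is not taken from the commit: `rewriter`, `linalgOp`, and `paddedSize` are assumed to exist in the surrounding pass, and a real driver would also configure `paddingValues` and related options.

```cpp
// Sketch: prescribe a size for dim 1 of operand 0 instead of relying on a
// constant upper bound. `paddedSize` is an OpFoldResult computed by the
// caller (e.g. a tile size or a tensor.dim).
linalg::LinalgPaddingOptions options;
options.paddingDimensions = {0, 1};
options.setSizeToPadTo(/*operandIndex=*/0, /*dimIndex=*/1, paddedSize);

linalg::LinalgOp paddedOp;
SmallVector<Value> replacements;
SmallVector<tensor::PadOp> newPadOps;
if (failed(linalg::rewriteAsPaddedOp(rewriter, linalgOp, options, paddedOp,
                                     replacements, newPadOps)))
  return failure();
```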

7 files changed: +246 −60 lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 7 additions & 3 deletions
```diff
@@ -1134,14 +1134,16 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
        DefaultValuedAttr<
          TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
          "{}">:$transpose_paddings,
-       DefaultValuedAttr<StrAttr, "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copy_back_op);
+       DefaultValuedAttr<StrAttr, "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copy_back_op,
+       DefaultValuedAttr<UnitAttr, "false">:$use_prescribed_tensor_shapes);
   let results = (outs TransformHandleTypeInterface:$padded,
                       TransformHandleTypeInterface:$pad,
                       TransformHandleTypeInterface:$copy);
 
   let assemblyFormat = [{
     $target
     (`pad_to_multiple_of` custom<DynamicIndexList>($pad_to_multiple_of, $static_pad_to_multiple_of)^)?
+    (`use_prescribed_tensor_shapes` $use_prescribed_tensor_shapes^)?
     attr-dict
     `:` functional-type(operands, results)
   }];
@@ -1159,13 +1161,15 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
                    CArg<"ArrayRef<int64_t>", "{}">:$staticPadToMultipleOf,
                    CArg<"ArrayRef<int64_t>", "{}">:$nofoldFlags,
                    CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
-                   CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>,
+                   CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp,
+                   CArg<"bool", "false">:$usePrescribedTensorShapes)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<int64_t>":$paddingDimensions,
                    "ArrayRef<OpFoldResult>":$mixedPadToMultipleOf,
                    CArg<"ArrayRef<int64_t>", "{}">:$nofoldFlags,
                    CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
-                   CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>
+                   CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp,
+                   CArg<"bool", "false">:$usePrescribedTensorShapes)>
   ];
 
   let extraClassDeclaration = [{
```
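For reference, the new builder argument might be exercised from C++ roughly as follows. This is a sketch under assumptions: `builder`, `loc`, and the `target` handle are taken from the surrounding transform-construction code, and the index lists are placeholders.

```cpp
// Build a transform.structured.pad op with the new flag set; all other
// trailing arguments keep their declared defaults.
auto padOp = builder.create<transform::PadOp>(
    loc, target,
    /*paddingDimensions=*/ArrayRef<int64_t>{0, 1},
    /*staticPadToMultipleOf=*/ArrayRef<int64_t>{},
    /*nofoldFlags=*/ArrayRef<int64_t>{},
    /*transposePaddings=*/ArrayRef<Attribute>{},
    /*copyBackOp=*/
    bufferization::MaterializeInDestinationOp::getOperationName(),
    /*usePrescribedTensorShapes=*/true);
```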

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 17 additions & 0 deletions
```diff
@@ -295,6 +295,23 @@ struct LinalgPaddingOptions {
     padToMultipleOf.emplace(m.begin(), m.end());
     return *this;
   }
+  /// A mapping between an operand and shape dim, and a size for a padding
+  /// dimension. Each size is expected to be greater or equal than the
+  /// corresponding shape dim. If no value is provided then the constant upper
+  /// bound will be used.
+  DenseMap<std::pair<unsigned, unsigned>, OpFoldResult> sizeToPadTo;
+  LinalgPaddingOptions &setSizeToPadTo(unsigned operandIndex, unsigned dimIndex,
+                                       OpFoldResult size) {
+    assert(size && "expected non-null size");
+    sizeToPadTo[{operandIndex, dimIndex}] = size;
+    return *this;
+  }
+  /// Given the operand index and shape dim it returns the size to pad to.
+  OpFoldResult getSizeToPadTo(unsigned operandIndex, unsigned dimIndex) const {
+    return sizeToPadTo.lookup_or(
+        std::pair<unsigned, unsigned>(operandIndex, dimIndex), nullptr);
+  }
+
   /// A flag for every operand to mark the PadOp as nofold which enables
   /// packing for statically shaped operands.
   SmallVector<bool> nofoldFlags;
```
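Callers query the map through `getSizeToPadTo`; a null result signals that no size was prescribed, in which case the documented fallback is the constant upper bound. A small sketch of the intended consumption pattern, assuming an `options` object as above:

```cpp
// A null OpFoldResult means "no prescribed size" for this (operand, dim).
OpFoldResult prescribed =
    options.getSizeToPadTo(/*operandIndex=*/0, /*dimIndex=*/1);
if (!prescribed) {
  // No entry in sizeToPadTo: compute the constant upper bound as before.
}
```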

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 10 additions & 7 deletions
```diff
@@ -71,12 +71,14 @@ bool isParallelIterator(utils::IteratorType iteratorType);
 /// Check if iterator type has "reduction" semantics.
 bool isReductionIterator(utils::IteratorType iteratorType);
 
-/// Create a tensor::PadOp that pads `source` to the size of the statically
-/// sized `type` whose static sizes are assumed to be greater than the dynamic
-/// `source` size. The padding introduces trailing `pad` values until the
-/// target size is met. If `source` is defined by one or more LinalgOps that
-/// have been padded with the same value and sizes, return their padded result
-/// instead of creating a tensor::PadOp.
+/// Create a tensor::PadOp that pads `source` to the shape of `type` whose sizes
+/// are assumed to be greater than the dynamic `source` size. If `typeDynDims`
+/// is specified, then it must contain the sizes of all the dynamic dimensions
+/// in order of appearance in `type`, otherwise the function will pad those
+/// values to `0`. The padding introduces trailing `pad` values until the target
+/// size is met. If `source` is defined by one or more LinalgOps that have been
+/// padded with the same value and sizes, return their padded result instead of
+/// creating a tensor::PadOp.
 ///
 /// Example:
 /// ```
@@ -91,7 +93,8 @@ bool isReductionIterator(utils::IteratorType iteratorType);
 /// %4 = tensor.pad %3 low[0, 0] high[...] { tensor.yield %other_cst }
 /// ```
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
-                            Value source, Value pad, bool nofold);
+                            Value source, Value padding, bool nofold,
+                            ValueRange typeDynDims = std::nullopt);
 
 /// Returns GenericOp that copies an n-D memref. Unlike the current
 /// implementation of memref::CopyOp, this op can further tile, lower to loops
```
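A hedged usage sketch of the extended signature (`b`, `loc`, `source`, `zero`, and `dimSize` are hypothetical names, not from the commit). `typeDynDims` supplies one SSA value per dynamic dimension of the result type, in order of appearance:

```cpp
// Pad a dynamically shaped source up to a partially dynamic target type.
auto paddedType =
    RankedTensorType::get({ShapedType::kDynamic, 64}, b.getF32Type());
// One size Value per dynamic dim of paddedType, in order of appearance.
SmallVector<Value> typeDynDims = {dimSize};
Value padded = linalg::makeComposedPadHighOp(
    b, loc, paddedType, source, /*padding=*/zero, /*nofold=*/false,
    typeDynDims);
```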

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 32 additions & 4 deletions
```diff
@@ -1907,7 +1907,8 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                              ArrayRef<int64_t> padToMultipleOf,
                              ArrayRef<int64_t> nofoldFlags,
                              ArrayRef<Attribute> transposePaddings,
-                             StringRef copyBackOp) {
+                             StringRef copyBackOp,
+                             bool usePrescribedTensorShapes) {
   auto resultType = transform::AnyOpType::get(b.getContext());
   return build(/*builder=*/b,
                /*result=*/result,
@@ -1922,15 +1923,18 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                    : b.getDenseI64ArrayAttr(padToMultipleOf)),
               /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
               /*transposePaddings=*/b.getArrayAttr(transposePaddings),
-              /*copyBackOp=*/b.getStringAttr(copyBackOp));
+              /*copyBackOp=*/b.getStringAttr(copyBackOp),
+              /*usePrescribedTensorShapes=*/
+              usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
 }
 
 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                              ArrayRef<int64_t> paddingDimensions,
                              ArrayRef<OpFoldResult> mixedPadToMultipleOf,
                              ArrayRef<int64_t> nofoldFlags,
                              ArrayRef<Attribute> transposePaddings,
-                             StringRef copyBackOp) {
+                             StringRef copyBackOp,
+                             bool usePrescribedTensorShapes) {
   auto resultType = transform::AnyOpType::get(b.getContext());
   SmallVector<int64_t> staticPadToMultipleOf;
   SmallVector<Value> dynamicPadToMultipleOf;
@@ -1946,7 +1950,8 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
         /*padToMultipleOf=*/staticPadToMultipleOf,
         /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
         /*transposePaddings=*/b.getArrayAttr(transposePaddings),
-        /*copyBackOp=*/b.getStringAttr(copyBackOp));
+        /*copyBackOp=*/copyBackOp,
+        /*usePrescribedTensorShapes=*/usePrescribedTensorShapes);
 }
 
 void PadOp::getEffects(
@@ -2051,11 +2056,34 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
   } else {
     llvm_unreachable("unsupported copy_back op");
   }
+  // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
+  bool irChanged = false;
+  if (getUsePrescribedTensorShapes() &&
+      linalgTarget.hasPureTensorSemantics()) {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(linalgTarget);
+    for (OpOperand &operand : linalgTarget->getOpOperands()) {
+      for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
+        if (!ShapedType::isDynamic(dim))
+          continue;
+        options.setSizeToPadTo(operand.getOperandNumber(), i,
+                               tensor::getMixedSize(rewriter,
+                                                    operand.get().getLoc(),
+                                                    operand.get(), i));
+        irChanged = true;
+      }
+    }
+  }
 
   SmallVector<Value> replacements;
   SmallVector<tensor::PadOp> newPadOps;
   if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
                                replacements, newPadOps))) {
+    if (irChanged) {
+      auto diag = emitDefiniteFailure() << "failed to pad op";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
     auto diag = emitSilenceableError() << "failed to pad op";
     diag.attachNote(target->getLoc()) << "target op";
     return diag;
```
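Note the asymmetry in failure reporting: once `tensor::getMixedSize` has materialized `tensor.dim` ops, the IR is already modified, so a later padding failure must be reported as a definite failure rather than a silenceable one. As a point of reference, here is a rough sketch of what `getMixedSize` yields per dimension; it is an assumption-laden paraphrase, not the actual implementation, and `rewriter`, `loc`, `v`, and `i` are assumed from context:

```cpp
// Static dims fold to an index attribute; dynamic dims materialize a
// tensor.dim op, which is why the loop above sets `irChanged`.
auto type = cast<RankedTensorType>(v.getType());
OpFoldResult size =
    type.isDynamicDim(i)
        ? OpFoldResult(rewriter.create<tensor::DimOp>(loc, v, i).getResult())
        : OpFoldResult(rewriter.getIndexAttr(type.getDimSize(i)));
```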
