Skip to content

[mlir][Transforms] Add a PadTilingInterface transformation and hook i… #144991

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1186,6 +1186,85 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
}];
}

//===----------------------------------------------------------------------===//
// PadTilingInterfaceOp
//===----------------------------------------------------------------------===//

def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interface",
    [FunctionalStyleTransformOpTrait, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     TransformOpInterface,
     ReportTrackingListenerFailuresOpTrait]> {
  let description = [{
    Pads the operations pointed to by the target handle using the options
    provided as operation attributes. The operation returns a handle to the
    padded operation and to the padding operation ("tensor.pad").

    TODO: in the future this should be moved out of a specific Linalg
    implementation file and into a more general "Structured" file.

    #### Return modes

    This operation ignores non-Linalg ops and drops them in the return.
    In the future, this operation will support all TilingInterfaceOps.

    This operation may produce a definite failure if the padding fails for any
    reason.

    If all the operations referred to by the `target` handle pad properly, the
    transform succeeds. Otherwise the transform produces a silenceable failure.
    The return handle points to only the subset of successfully produced
    padded operations, which can be empty.
  }];

  let arguments =
    (ins TransformHandleTypeInterface:$target,
         DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
         DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
         Variadic<TransformAnyParamTypeOrAnyHandle>:$padding_sizes,
         DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
             $static_padding_sizes,
         DefaultValuedAttr<UnitAttr, "false">:$pad_to_multiple_of);
  let results = (outs TransformHandleTypeInterface:$padded,
                      TransformHandleTypeInterface:$pad);

  let assemblyFormat = [{
    $target
    `to`
    (`padding_sizes` custom<DynamicIndexList>($padding_sizes, $static_padding_sizes)^)?
    (`pad_to_multiple_of` $pad_to_multiple_of^)?
    attr-dict
    `:` functional-type(operands, results)
  }];

  let hasVerifier = 1;

  let builders = [
    // Builder for a transform::PadTilingInterfaceOp with automatic inference
    // of padding value. Warning: this will set the value 0 for the inferred
    // elemental type without taking the op into account and thus only work
    // for the add/mul ring at the moment.
    // TODO: support other operations (e.g. min, max etc).
    OpBuilder<(ins "Value":$target,
                   "ArrayRef<int64_t>":$paddingDimensions,
                   CArg<"ArrayRef<int64_t>", "{}">:$staticPaddingSizes,
                   CArg<"bool", "false">:$padToMultipleOf)>,
    // Same as above, but taking a mix of static sizes and dynamic (SSA)
    // padding sizes. Padding-value inference applies here as well.
    OpBuilder<(ins "Value":$target,
                   "ArrayRef<int64_t>":$paddingDimensions,
                   "ArrayRef<OpFoldResult>":$mixedPaddingSizes,
                   CArg<"bool", "false">:$padToMultipleOf)>
  ];

  let extraClassDeclaration = [{
    /// Returns a mix of dynamic `padding_sizes` and static `static_padding_sizes`.
    SmallVector<OpFoldResult> getMixedPaddingSizes();

    ::mlir::DiagnosedSilenceableFailure apply(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::transform::TransformResults &results,
        ::mlir::transform::TransformState &state);
  }];
}

//===----------------------------------------------------------------------===//
// HoistPadOp
//===----------------------------------------------------------------------===//
Expand Down
79 changes: 76 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/DialectConversion.h"
Expand Down Expand Up @@ -347,6 +348,34 @@ struct LinalgPaddingOptions {
}
};

/// Options used to control padding driven by the `TilingInterface`-based
/// padding transform (see `rewriteAsPaddedOp` overload taking these options).
/// All setters return `*this` to allow fluent chaining.
struct PadTilingInterfaceOptions {
  /// A padding value for every operand.
  SmallVector<Attribute> paddingValues;
  PadTilingInterfaceOptions &setPaddingValues(ArrayRef<Attribute> pv) {
    paddingValues.assign(pv.begin(), pv.end());
    return *this;
  }
  /// A list of iterator dimensions to pad.
  SmallVector<int64_t> paddingDimensions;
  PadTilingInterfaceOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
    paddingDimensions.assign(pd.begin(), pd.end());
    return *this;
  }
  /// A list of iterator dimensions sizes to pad to.
  SmallVector<OpFoldResult> paddingSizes;
  PadTilingInterfaceOptions &setPaddingSizes(ArrayRef<OpFoldResult> m) {
    paddingSizes.assign(m.begin(), m.end());
    return *this;
  }
  /// Pad iterator `paddingDimension[i]` to next multiple of `paddingSizes[i]`
  /// if true. Otherwise pad to `paddingSizes[i]`.
  /// Initialized to `false` so a default-constructed options object has a
  /// determinate value even when the setter is never called.
  bool padToMultipleOf = false;
  PadTilingInterfaceOptions &setPadToMultipleOf(bool b) {
    padToMultipleOf = b;
    return *this;
  }
};

/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
/// smallest constant value for the size of the buffer needed for each
Expand Down Expand Up @@ -542,9 +571,9 @@ SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op);
/// where relevant.
void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);

/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
/// to a static bounding box. The original `opToPad` is cloned and operates on
/// the padded tensors.
/// Pad the iterator dimensions `options.paddingDimensions` of all `opToPad`
/// operands to a static bounding box. The original `opToPad` is cloned and
/// operates on the padded tensors.
///
/// * "options.padToMultipleOf" indicates that each padding dimension should be
/// padded to the specified multiple.
Expand All @@ -561,6 +590,50 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
SmallVector<Value> &replacements,
SmallVector<tensor::PadOp> &padOps);

/// Helper function to compute the padded shape of the given value `v` of
/// `RankedTensorType` given:
/// - the `indexingSizes` as a list of OpFoldResult.
/// - an `indexingMap` that encodes how the padded shape varies with
/// increases in `indexingSizes`.
/// The implementation iteratively combines increases from contributing
/// dimensions using affine.apply operations.
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
/// provides a gentle portability path for Linalg-like ops with affine maps.
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult>
computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
const PadTilingInterfaceOptions &options);

using PadSizeComputationFunction =
std::function<FailureOr<SmallVector<OpFoldResult>>(
RewriterBase &, OpOperand &, ArrayRef<Range>,
const PadTilingInterfaceOptions &)>;

/// Specific helper for Linalg ops.
FailureOr<SmallVector<OpFoldResult>>
computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
ArrayRef<Range> iterationDomain,
const PadTilingInterfaceOptions &options);

/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
///
/// * "options.paddingSizes" indicates that each padding dimension should be
///   padded to the specified padding size.
/// * "options.padToMultipleOf" indicates that each padding dimension should
///   instead be padded to the next multiple of the corresponding padding
///   size.
/// * Use "options.paddingValues" to set the padding value of the created
///   tensor::PadOp.
/// * The tensor::PadOp is returned on success.
FailureOr<TilingInterface>
rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
const PadTilingInterfaceOptions &constOptions,
SmallVector<tensor::PadOp> &padOps,
PadSizeComputationFunction computePaddingSizeFun =
&computeLinalgPaddedShape);

namespace detail {

/// Helper struct to hold the results of building a packing loop nest.
Expand Down
161 changes: 161 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include <type_traits>

using namespace mlir;
Expand Down Expand Up @@ -2155,6 +2156,166 @@ LogicalResult transform::PadOp::verify() {
return success();
}

//===---------------------------------------------------------------------===//
// PadTilingInterfaceOp
//===---------------------------------------------------------------------===//

// All-static builder: the padding sizes are attached as a DenseI64ArrayAttr
// (no dynamic SSA sizes). An empty `paddingSizes` leaves the attribute unset.
void transform::PadTilingInterfaceOp::build(OpBuilder &b,
                                            OperationState &result,
                                            Value target,
                                            ArrayRef<int64_t> paddingDimensions,
                                            ArrayRef<int64_t> paddingSizes,
                                            bool padToMultipleOf) {
  // Both results (`padded` and `pad`) are untyped "any op" handles.
  auto resultType = transform::AnyOpType::get(b.getContext());
  return build(/*builder=*/b,
               /*result=*/result,
               /*types=*/TypeRange{resultType, resultType},
               /*target=*/target,
               /*paddingValues=*/ArrayAttr(), // let inference handle this
               /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
               /*paddingSizes=*/ValueRange{}, // no dynamic sizes
               /*staticPaddingSizes=*/
               (paddingSizes.empty() ? DenseI64ArrayAttr()
                                     : b.getDenseI64ArrayAttr(paddingSizes)),
               /*padToMultipleOf=*/
               padToMultipleOf ? b.getUnitAttr() : nullptr);
}

// Mixed builder: splits `mixedPaddingSizes` into dynamic SSA values and a
// static DenseI64ArrayAttr before delegating to the generated builder.
void transform::PadTilingInterfaceOp::build(
    OpBuilder &b, OperationState &result, Value target,
    ArrayRef<int64_t> paddingDimensions,
    ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
  // Both results (`padded` and `pad`) are untyped "any op" handles.
  auto resultType = transform::AnyOpType::get(b.getContext());
  SmallVector<int64_t> staticPaddingSizes;
  SmallVector<Value> dynamicPaddingSizes;
  dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
                             staticPaddingSizes);
  return build(/*builder=*/b,
               /*result=*/result,
               /*types=*/TypeRange{resultType, resultType},
               /*target=*/target,
               /*paddingValues=*/ArrayAttr(), // let inference handle this
               /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
               /*paddingSizes=*/dynamicPaddingSizes,
               /*staticPaddingSizes=*/staticPaddingSizes,
               /*padToMultipleOf=*/padToMultipleOf);
}

// Declares the memory effects of this transform op on its handles and on
// the payload IR, as required by MemoryEffectsOpInterface.
void transform::PadTilingInterfaceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The target handle is consumed: the payload op it points to is rewritten
  // (see `apply`), so the handle is invalidated.
  consumesHandle(getTargetMutable(), effects);
  // Dynamic padding sizes are only inspected, never invalidated.
  onlyReadsHandle(getPaddingSizesMutable(), effects);
  // Both results (`padded` and `pad`) are fresh handles.
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

/// Merge the dynamic `padding_sizes` operands and the static
/// `static_padding_sizes` attribute into a single list of OpFoldResults,
/// preserving their interleaved order.
SmallVector<OpFoldResult>
transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
  Builder builder(getContext());
  SmallVector<OpFoldResult> mixedSizes =
      getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), builder);
  return mixedSizes;
}

/// Applies the padding transform to every payload op associated with the
/// `target` handle:
///   1. require the payload to be a Linalg op (via TilingInterface),
///   2. convert/parse the `padding_values` attribute entries against the
///      op's operand element types,
///   3. call `rewriteAsPaddedOp` with the assembled options,
///   4. map the `padded` result handle to the padded ops and the `pad`
///      result handle to the created tensor.pad ops.
/// NOTE(review): the op description claims non-Linalg targets are "dropped
/// in the return", but this implementation raises a silenceable error for
/// them instead — confirm which contract is intended.
DiagnosedSilenceableFailure
transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  SmallVector<Operation *> paddedOps, padOps;

  for (Operation *target : state.getPayloadOps(getTarget())) {
    auto targetOp = dyn_cast<TilingInterface>(target);
    if (!targetOp) {
      auto diag = emitSilenceableError() << "expected TilingInterface target";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    // Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
    // map / C++ APIs to compute the effect of padding on operands.
    if (!isa<LinalgOp>(targetOp.getOperation())) {
      auto diag = emitSilenceableError() << "only LinalgOp supported atm";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    // Convert the padding values to attributes. The zip stops at the shorter
    // of the two ranges, so extra operands without a padding value are
    // silently left without one.
    SmallVector<Attribute> paddingValues;
    for (auto const &[untypedAttr, elementOrTensorType] :
         llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
      auto attr = dyn_cast<TypedAttr>(untypedAttr);
      Type elementType = getElementTypeOrSelf(elementOrTensorType);
      if (!attr) {
        emitOpError("expects padding values to be typed attributes");
        return DiagnosedSilenceableFailure::definiteFailure();
      }
      // Try to parse string attributes to obtain an attribute of element type.
      if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
        auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
            stringAttr, getContext(), elementType,
            /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
        if (!parsedAttr || parsedAttr.getType() != elementType) {
          auto diag = this->emitOpError("expects a padding that parses to ")
                      << elementType << ", got " << attr;
          diag.attachNote(targetOp.getLoc()) << "when applied to this op";
          return DiagnosedSilenceableFailure::definiteFailure();
        }
        paddingValues.push_back(parsedAttr);
        continue;
      }
      // Otherwise, add the attribute directly; its type must match the
      // operand's element type exactly.
      if (attr.getType() != elementType) {
        auto diag = this->emitOpError("expects a padding value of type ")
                    << elementType << ", got " << attr;
        diag.attachNote(targetOp.getLoc()) << "when applied to this op";
        return DiagnosedSilenceableFailure::definiteFailure();
      }
      paddingValues.push_back(attr);
    }

    // Set options.
    PadTilingInterfaceOptions options;
    options.setPaddingValues(paddingValues)
        .setPaddingDimensions(
            extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions()))
        .setPaddingSizes(getMixedPaddingSizes())
        .setPadToMultipleOf(getPadToMultipleOf());

    // Apply padding.
    SmallVector<tensor::PadOp> newPadOps;
    FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
        rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
        newPadOps);
    if (failed(maybePaddedOp)) {
      auto diag = emitSilenceableError() << "failed to pad op";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    // Set transform results.
    paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
    padOps.append(newPadOps.begin(), newPadOps.end());
  }

  results.set(cast<OpResult>(getPadded()), paddedOps);
  results.set(cast<OpResult>(getPad()), padOps);
  return DiagnosedSilenceableFailure::success();
}

/// Static verification: padding dimensions must be non-negative and there
/// must be exactly one (static or dynamic) padding size per padding
/// dimension.
LogicalResult transform::PadTilingInterfaceOp::verify() {
  SmallVector<int64_t> paddingDimensions =
      extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
  // Dimension indices are 0-based, so 0 is valid; only negatives are
  // rejected. (The previous message said "positive", contradicting the
  // `< 0` predicate.)
  if (any_of(paddingDimensions,
             [](int64_t paddingDimension) { return paddingDimension < 0; })) {
    return emitOpError() << "expects padding_dimensions to contain "
                            "non-negative integers, found "
                         << getPaddingDimensions();
  }
  // `getMixedPaddingSizes` merges dynamic operands and the static attribute;
  // its total length must match the number of padded dimensions.
  if (getMixedPaddingSizes().size() != paddingDimensions.size()) {
    return emitOpError()
           << "expects as many padding sizes as padding_dimensions";
  }
  return success();
}

//===---------------------------------------------------------------------===//
// HoistPadOp
//===---------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
Padding.cpp
PadTilingInterface.cpp
Promotion.cpp
RuntimeOpVerification.cpp
Specialize.cpp
Expand Down
Loading
Loading