Skip to content

Commit 6cc3bf7

Browse files
authored
[mlir][tensor] Add canonicalization to fold consecutive tensor.pad ops (#107302)
`tensor.pad(tensor.pad)` with the same constant padding value can be combined into a single pad that pads to the sum of the high and low padding amounts.
1 parent ea92045 commit 6cc3bf7

File tree

2 files changed

+161
-1
lines changed

2 files changed

+161
-1
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3402,12 +3402,90 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
34023402
}
34033403
};
34043404

3405+
/// Folds a chain of `tensor.pad` ops with the same constant padding value.
3406+
///
3407+
/// Example:
3408+
///
3409+
/// ```mlir
3410+
/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
3411+
/// tensor.yield %val
3412+
/// } : tensor<1x2xf32> to tensor<2x5xf32>
3413+
/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
3414+
/// tensor.yield %val
3415+
/// } : tensor<1x5xf32> to tensor<5x7xf32>
3416+
/// ```
3417+
///
3418+
/// folds into:
3419+
///
3420+
/// ```mlir
3421+
/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
3422+
/// tensor.yield %val
3423+
/// } : tensor<1x2xf32> to tensor<5x7xf32>
3424+
/// ```
3425+
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
3426+
using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3427+
3428+
LogicalResult matchAndRewrite(tensor::PadOp padOp,
3429+
PatternRewriter &rewriter) const override {
3430+
if (padOp.getNofold()) {
3431+
return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
3432+
}
3433+
3434+
auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3435+
if (!producerPad || producerPad.getNofold()) {
3436+
return rewriter.notifyMatchFailure(
3437+
padOp, "producer is not a foldable tensor.pad op");
3438+
}
3439+
3440+
// Fail if the tensor::PadOps padding values do not match.
3441+
Value consumerPadValue = padOp.getConstantPaddingValue();
3442+
Value producerPadValue = producerPad.getConstantPaddingValue();
3443+
if (!consumerPadValue || !producerPadValue ||
3444+
consumerPadValue != producerPadValue) {
3445+
return rewriter.notifyMatchFailure(
3446+
padOp,
3447+
"cannot fold PadOps with different or non-constant padding values");
3448+
}
3449+
3450+
Location loc = padOp.getLoc();
3451+
AffineExpr d0, d1;
3452+
bindDims(rewriter.getContext(), d0, d1);
3453+
3454+
// Combine the low/high paddings of the two tensor::PadOps.
3455+
auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3456+
ArrayRef<OpFoldResult> producerPaddings) {
3457+
SmallVector<OpFoldResult> sumPaddings;
3458+
for (auto [consumerIndex, producerIndex] :
3459+
llvm::zip_equal(consumerPaddings, producerPaddings)) {
3460+
sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3461+
rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3462+
}
3463+
return sumPaddings;
3464+
};
3465+
3466+
SmallVector<OpFoldResult> newHighPad =
3467+
addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3468+
SmallVector<OpFoldResult> newLowPad =
3469+
addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3470+
3471+
auto newPadOp = rewriter.create<tensor::PadOp>(
3472+
padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
3473+
newLowPad, newHighPad, padOp.getNofold(),
3474+
getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
3475+
rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
3476+
newPadOp.getRegion().begin());
3477+
rewriter.replaceOp(padOp, newPadOp.getResult());
3478+
return success();
3479+
}
3480+
};
3481+
34053482
} // namespace
34063483

34073484
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
34083485
MLIRContext *context) {
34093486
results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3410-
FoldOrthogonalPaddings, FoldStaticPadding>(context);
3487+
FoldOrthogonalPaddings, FoldStaticPadding,
3488+
FoldConsecutiveConstantPadding>(context);
34113489
}
34123490

34133491
/// Return the padding value of the PadOp if it is constant. In this context,

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1964,6 +1964,88 @@ func.func @dont_fold_pad_chains(%arg0: tensor<64x64xf32>,
19641964

19651965
// -----
19661966

1967+
// Two stacked pads with the same constant value must merge into one pad
// whose low/high amounts are the element-wise sums: low[1,1]+[0,2] = [1,3],
// high[1,0]+[3,2] = [4,2].
// CHECK-LABEL: func @merge_constant_padding
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3xf32>
// CHECK-SAME: %[[PADVAL:[A-Za-z0-9]+]]: f32
// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[1, 3] high[4, 2]
// CHECK: tensor.yield %[[PADVAL]]
// CHECK: return %[[PAD]]
func.func @merge_constant_padding(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
  // Inner pad: 2x3 -> 4x4.
  %inner = tensor.pad %arg0 low[1, 1] high[1, 0] {
    ^bb0(%i0: index, %i1: index):
      tensor.yield %pad_value : f32
  } : tensor<2x3xf32> to tensor<4x4xf32>
  // Outer pad with the same constant value: 4x4 -> 7x8.
  %outer = tensor.pad %inner low[0, 2] high[3, 2] {
    ^bb0(%i2: index, %i3: index):
      tensor.yield %pad_value : f32
  } : tensor<4x4xf32> to tensor<7x8xf32>
  return %outer : tensor<7x8xf32>
}
1984+
1985+
// -----
1986+
1987+
// Merging with a dynamic amount: the shared dim-0 high side becomes
// %idx + 1, materialized as an affine.apply on the merged pad.
// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 + 1)>
// CHECK-LABEL: func @merge_constant_padding_dynamic
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[IDX:[A-Za-z0-9]+]]: index
// CHECK-SAME: %[[PADVAL:[A-Za-z0-9]+]]: f32
// CHECK: %[[HIGH:.+]] = affine.apply #[[$MAP]]()[%[[IDX]]]
// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[%[[IDX]], 3] high[%[[HIGH]], 2]
// CHECK: tensor.yield %[[PADVAL]]
// CHECK: return %[[PAD]]
func.func @merge_constant_padding_dynamic(%arg0: tensor<?x?xf32>, %idx: index, %pad_value: f32) -> tensor<?x?xf32> {
  // Inner pad: dynamic low amount on dim 0.
  %inner = tensor.pad %arg0 low[%idx, 1] high[1, 0] {
    ^bb0(%i0: index, %i1: index):
      tensor.yield %pad_value : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  // Outer pad: dynamic high amount on dim 0.
  %outer = tensor.pad %inner low[0, 2] high[%idx, 2] {
    ^bb0(%i2: index, %i3: index):
      tensor.yield %pad_value : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  return %outer : tensor<?x?xf32>
}
2007+
2008+
// -----
2009+
2010+
// Verify that folding does not happen if it would drop a nofold attribute
// CHECK-LABEL: func @dont_merge_constant_padding_nofold
// CHECK: tensor.pad {{.*}} nofold
// CHECK: tensor.pad
func.func @dont_merge_constant_padding_nofold(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
  // Producer pad carries nofold, so the chain must stay as two ops.
  %nofold_pad = tensor.pad %arg0 nofold low[1, 1] high[1, 0] {
    ^bb0(%i0: index, %i1: index):
      tensor.yield %pad_value : f32
  } : tensor<2x3xf32> to tensor<4x4xf32>
  %outer = tensor.pad %nofold_pad low[0, 2] high[3, 2] {
    ^bb0(%i2: index, %i3: index):
      tensor.yield %pad_value : f32
  } : tensor<4x4xf32> to tensor<7x8xf32>
  return %outer : tensor<7x8xf32>
}
2025+
2026+
// -----
2027+
2028+
// Verify that folding does not happen when the two pads use different
// padding values.
// CHECK-LABEL: func @dont_merge_constant_padding_different_vals
// CHECK: tensor.pad
// CHECK: tensor.pad
func.func @dont_merge_constant_padding_different_vals(
    %arg0: tensor<2x3xf32>,
    %pad_value0: f32,
    %pad_value1: f32) -> tensor<7x8xf32> {
  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
    ^bb0(%b0: index, %b1 : index):
      tensor.yield %pad_value0 : f32
  } : tensor<2x3xf32> to tensor<4x4xf32>
  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
    ^bb0(%b2: index, %b3 : index):
      tensor.yield %pad_value1 : f32
  } : tensor<4x4xf32> to tensor<7x8xf32>
  return %pad1 : tensor<7x8xf32>
}
2046+
2047+
// -----
2048+
19672049
// CHECK-LABEL: func @fold_collapse_shape_from_elements
19682050
func.func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor<i32> {
19692051
// CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<i32>

0 commit comments

Comments
 (0)