
Commit 973dbe2

Author: gysit
Parent: 40ad667

[mlir][tensor] Add pattern to fold ExtractSliceOp, PadOp chains.

The pattern folds chains of tensor::ExtractSliceOp, tensor::PadOp pairs if they pad different dimensions. Repeated tiling and padding of the tiled dimensions may introduce such chains. This canonicalization pattern folds these chains into a single tensor::ExtractSliceOp, tensor::PadOp pair that pads all dimensions at once, which simplifies vectorization and bufferization.

Example:

```mlir
%0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
    : tensor<64x64xf32> to tensor<?x64xf32>
%1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ... }
    : tensor<?x64xf32> to tensor<8x64xf32>
%2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
    : tensor<8x64xf32> to tensor<8x?xf32>
%res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ... }
    : tensor<8x?xf32> to tensor<8x4xf32>
```

folds into:

```mlir
%0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
    : tensor<64x64xf32> to tensor<?x?xf32>
%res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ... }
    : tensor<?x?xf32> to tensor<8x4xf32>
```

Reviewed By: nicolasvasilache, hanchung

Differential Revision: https://reviews.llvm.org/D122722
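The message leaves one precondition implicit: the rewrite also bails out when the outer tensor::PadOp itself carries nofold (see the early exit at the top of matchAndRewrite in the diff below). A minimal sketch of a chain the pattern therefore leaves untouched; the function name and the filled-in pad bodies are illustrative assumptions:

```mlir
func.func @nofold_outer_pad_blocks(%input: tensor<64x64xf32>,
                                   %sz0: index, %sz1: index,
                                   %pw0: index, %pw1: index) -> tensor<8x4xf32> {
  %cst = arith.constant 0.0 : f32
  %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
      : tensor<64x64xf32> to tensor<?x64xf32>
  // nofold on the *outer* pad: FoldOrthogonalPaddings does not apply.
  %1 = tensor.pad %0 nofold low[0, 0] high[%pw0, 0] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<?x64xf32> to tensor<8x64xf32>
  %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
      : tensor<8x64xf32> to tensor<8x?xf32>
  %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<8x?xf32> to tensor<8x4xf32>
  func.return %3 : tensor<8x4xf32>
}
```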

3 files changed: +263 −4 lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 3 additions & 1 deletion. The first hunk's change at line 681 shows no visible content on either side (a whitespace-only line change); the second hunk declares getPaddedDims:

```diff
@@ -678,7 +678,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Tensor_Op<mnemonic, !listconcat(traits, [NoSideEffect])>,
     Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyTensor:$result)> {
-
+
   code commonExtraClassDeclaration = [{
     static StringRef getReassociationAttrName() { return "reassociation"; }
     SmallVector<AffineMap, 4> getReassociationMaps();
@@ -982,6 +982,8 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
       return getConstantIntValue(ofr) == static_cast<int64_t>(0);
     });
   }
+  /// Return the dimensions with a non-zero low or high padding.
+  llvm::SmallBitVector getPaddedDims();
 }];

 let builders = [
```

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 171 additions & 3 deletions. The first hunk implements the PadOp::getPaddedDims helper declared above:

```diff
@@ -1858,6 +1858,18 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
   result.addAttributes(attrs);
 }

+llvm::SmallBitVector PadOp::getPaddedDims() {
+  llvm::SmallBitVector paddedDims(getSourceType().getRank());
+  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
+    for (const auto &en : enumerate(paddingWidths))
+      if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
+        paddedDims.set(en.index());
+  };
+  extractPaddedDims(getMixedLowPad());
+  extractPaddedDims(getMixedHighPad());
+  return paddedDims;
+}
+
 namespace {
 // Folds tensor.pad when padding is static zeros and the attribute
 // doesn't request otherwise.
```
The second hunk adds the pattern itself and registers it with PadOp's canonicalization patterns:

```diff
@@ -1940,13 +1952,169 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
     return success();
   }
 };
+
+/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
+/// different dimensions. The pattern applies if the following preconditions
+/// hold:
+/// 1) the tensor::ExtractSliceOps are not rank-reducing,
+/// 2) the tensor::ExtractSliceOps have only unit-strides,
+/// 3) the tensor::PadOps perform only high-padding,
+/// 4) the tensor::PadOps have the same constant padding value,
+/// 5) the tensor::PadOps do not have common padding dimensions,
+/// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
+///    zero-offset for every dimension,
+/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
+///    padded source dimensions.
+///
+/// Example:
+///
+/// ```mlir
+///   %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
+///       : tensor<64x64xf32> to tensor<?x64xf32>
+///   %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
+///     } : tensor<?x64xf32> to tensor<8x64xf32>
+///   %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
+///       : tensor<8x64xf32> to tensor<8x?xf32>
+///   %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
+///     } : tensor<8x?xf32> to tensor<8x4xf32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
+///       : tensor<64x64xf32> to tensor<?x?xf32>
+///   %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
+///     } : tensor<?x?xf32> to tensor<8x4xf32>
+/// ```
+struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
+  using OpRewritePattern<PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    auto innerSliceOp = padOp.source().getDefiningOp<ExtractSliceOp>();
+    if (!innerSliceOp)
+      return failure();
+    auto outerPadOp = innerSliceOp.source().getDefiningOp<PadOp>();
+    if (!outerPadOp || outerPadOp.nofold())
+      return failure();
+    auto outerSliceOp = outerPadOp.source().getDefiningOp<ExtractSliceOp>();
+    if (!outerSliceOp)
+      return failure();
+
+    // 1) Fail if the chain is rank-reducing.
+    int64_t rank = padOp.getSourceType().getRank();
+    if (outerSliceOp.getSourceType().getRank() != rank) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "cannot fold rank-reducing chain");
+    }
+
+    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
+    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot fold non-unit stride ExtractSliceOps");
+    }
+
+    // 3) Fail if the tensor::PadOps have non-zero low padding.
+    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "cannot fold PadOps with low padding");
+    }
+
+    // 4) Fail if the tensor::PadOps padding values do not match.
+    Attribute innerAttr, outerAttr;
+    Value innerValue = padOp.getConstantPaddingValue();
+    Value outerValue = outerPadOp.getConstantPaddingValue();
+    if (!innerValue || !outerValue ||
+        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
+        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
+        innerAttr != outerAttr) {
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot fold PadOps with different padding values");
+    }
+
+    // 5) Fail if a dimension is padded by both tensor::PadOps.
+    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
+    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
+    if (innerDims.anyCommon(outerDims)) {
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot fold PadOps with common padding dimensions");
+    }
+
+    // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
+    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
+    // for every dimension, and use the offset of the other pair. Fail if no
+    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
+    // exists.
+    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
+    for (auto &en : enumerate(newOffsets)) {
+      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
+      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
+      if (!innerDims.test(en.index()) &&
+          (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
+        en.value() = outerOffset;
+        continue;
+      }
+      if (!outerDims.test(en.index()) &&
+          (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
+        en.value() = innerOffset;
+        continue;
+      }
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot find zero-offset and zero-padding pair");
+    }
+
+    // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
+    // of the outer tensor::ExtractSliceOp for the dimensions padded by the
+    // outer tensor::PadOp and fail if the size of the inner
+    // tensor::ExtractSliceOp does not match the size of the padded dimension.
+    // Otherwise, take the size of the inner tensor::ExtractSliceOp.
+    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
+    for (auto &en : enumerate(newSizes)) {
+      if (!outerDims.test(en.index()))
+        continue;
+      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
+      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
+      assert(!ShapedType::isDynamic(sourceSize) &&
+             "expected padded dimension to have a static size");
+      if (getConstantIntValue(sliceSize) != sourceSize) {
+        return rewriter.notifyMatchFailure(
+            padOp, "cannot fold since the inner ExtractSliceOp size does not "
+                   "match the size of the outer padding");
+      }
+      en.value() = outerSliceOp.getMixedSizes()[en.index()];
+    }
+
+    // Combine the high paddings of the two tensor::PadOps.
+    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
+    for (auto &en : enumerate(newHighPad)) {
+      if (innerDims.test(en.index()))
+        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
+      if (outerDims.test(en.index()))
+        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
+    }
+
+    // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
+    // the two paddings in one step.
+    auto newSliceOp = rewriter.create<ExtractSliceOp>(
+        padOp.getLoc(), outerSliceOp.source(), newOffsets, newSizes,
+        innerSliceOp.getMixedStrides());
+    auto newPadOp = rewriter.create<PadOp>(
+        padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
+        padOp.getMixedLowPad(), newHighPad, padOp.nofold());
+    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
+                                newPadOp.getRegion().begin());
+    rewriter.replaceOp(padOp, newPadOp.getResult());
+    return success();
+  }
+};
+
 } // namespace

 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results
-      .add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast>(
-          context);
+  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
+              FoldOrthogonalPaddings>(context);
 }

 /// Return the padding value of the PadOp if it constant. In this context,
```
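Reading steps 6) and 7) against the headline example, this is the combined pair the pattern emits, annotated per dimension; the wrapper function, its argument names, and the pad body are assumptions added so the snippet parses standalone:

```mlir
func.func @combined(%input: tensor<64x64xf32>, %sz0: index, %sz1: index,
                    %pw0: index, %pw1: index) -> tensor<8x4xf32> {
  %cst = arith.constant 0.0 : f32
  // Dim 0: padded by the outer pad only (outerDims = {0}). The inner pair
  // has zero offset and zero padding there, so step 6) takes the outer
  // offset 16; step 7) checks the inner slice size 8 against the padded
  // source size 8 and takes the outer size %sz0; the high pad is %pw0.
  // Dim 1: the mirror image (innerDims = {1}): offset 4, size %sz1, and
  // high pad %pw1 all come from the inner pair.
  %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
      : tensor<64x64xf32> to tensor<?x?xf32>
  %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<8x4xf32>
  func.return %res : tensor<8x4xf32>
}
```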

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 89 additions & 0 deletions

```diff
@@ -1252,6 +1252,95 @@ func @pad_nofold_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tenso

 // -----

+// CHECK-LABEL: func @fold_orthogonal_pad_chains(
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<64x64xf32>,
+//  CHECK-SAME:   %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index
+func.func @fold_orthogonal_pad_chains(%arg0: tensor<64x64xf32>,
+                                      %sz0 : index, %sz1 : index,
+                                      %pw0 : index, %pw1 : index) -> tensor<8x4xf32> {
+  //      CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME:   [16, 4] [%[[SZ0]], %[[SZ1]]]
+  //      CHECK: %[[PAD:.*]] = tensor.pad %[[T0]] nofold
+  // CHECK-SAME:   high[%[[PW0]], %[[PW1]]]
+  //      CHECK: return %[[PAD]]
+  %pad_value = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
+  %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %pad_value : f32
+  } : tensor<?x64xf32> to tensor<8x64xf32>
+  %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
+  %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %pad_value : f32
+  } : tensor<8x?xf32> to tensor<8x4xf32>
+  func.return %3 : tensor<8x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dont_fold_pad_chains(
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<64x64xf32>,
+//  CHECK-SAME:   %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index
+func.func @dont_fold_pad_chains(%arg0: tensor<64x64xf32>,
+                                %sz0 : index, %sz1 : index,
+                                %pw0 : index, %pw1 : index) -> (tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>) {
+  //      CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  //      CHECK: %[[T1:.*]] = tensor.pad %[[T0]]
+  %pad_value = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
+  %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %pad_value : f32
+  } : tensor<?x64xf32> to tensor<8x64xf32>
+
+  // Don't fold if the padding values are different.
+  //      CHECK: %[[T2:.*]] = tensor.extract_slice %[[T1]]
+  // CHECK-SAME:   [0, 4] [8, %[[SZ1]]]
+  //      CHECK: %[[PAD0:.*]] = tensor.pad %[[T2]]
+  %different_value = arith.constant 1.0 : f32
+  %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
+  %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %different_value : f32
+  } : tensor<8x?xf32> to tensor<8x4xf32>
+
+  // Don't fold if the pad ops have common padding dimensions.
+  //      CHECK: %[[T3:.*]] = tensor.extract_slice %[[T1]]
+  // CHECK-SAME:   [4, 0] [%[[SZ1]], 64]
+  //      CHECK: %[[PAD1:.*]] = tensor.pad %[[T3]]
+  %4 = tensor.extract_slice %1[4, 0] [%sz1, 64] [1, 1] : tensor<8x64xf32> to tensor<?x64xf32>
+  %5 = tensor.pad %4 nofold low[0, 0] high[%pw1, 0] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %pad_value : f32
+  } : tensor<?x64xf32> to tensor<4x64xf32>
+
+  // Don't fold if the padded source tensor dimension is accessed at an offset.
+  //      CHECK: %[[T4:.*]] = tensor.extract_slice %[[T1]]
+  // CHECK-SAME:   [%[[SZ0]], 4] [8, %[[SZ1]]
+  //      CHECK: %[[PAD2:.*]] = tensor.pad %[[T4]]
+  %6 = tensor.extract_slice %1[%sz0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
+  %7 = tensor.pad %6 nofold low[0, 0] high[0, %pw1] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %pad_value : f32
+  } : tensor<8x?xf32> to tensor<8x4xf32>
+
+  // Don't fold if a padded source tensor dimension is sliced.
+  //      CHECK: %[[T5:.*]] = tensor.extract_slice %[[T1]]
+  // CHECK-SAME:   [0, 4] [6, %[[SZ1]]
+  //      CHECK: %[[PAD3:.*]] = tensor.pad %[[T5]]
+  %8 = tensor.extract_slice %1[0, 4] [6, %sz1] [1, 1] : tensor<8x64xf32> to tensor<6x?xf32>
+  %9 = tensor.pad %8 nofold low[0, 0] high[0, %pw1] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %pad_value : f32
+  } : tensor<6x?xf32> to tensor<6x4xf32>
+
+  // CHECK: return %[[PAD0]], %[[PAD1]], %[[PAD2]], %[[PAD3]]
+  func.return %3, %5, %7, %9 : tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_collapse_shape_from_elements
 func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor<i32> {
   // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<i32>
```
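These checks are driven by the RUN line at the top of canonicalize.mlir, which the diff does not show; a plausible minimal equivalent (the exact flag set is an assumption):

```mlir
// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
```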