@@ -1858,6 +1858,18 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
1858
1858
result.addAttributes (attrs);
1859
1859
}
1860
1860
1861
+ llvm::SmallBitVector PadOp::getPaddedDims () {
1862
+ llvm::SmallBitVector paddedDims (getSourceType ().getRank ());
1863
+ auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
1864
+ for (const auto &en : enumerate(paddingWidths))
1865
+ if (getConstantIntValue (en.value ()) != static_cast <int64_t >(0 ))
1866
+ paddedDims.set (en.index ());
1867
+ };
1868
+ extractPaddedDims (getMixedLowPad ());
1869
+ extractPaddedDims (getMixedHighPad ());
1870
+ return paddedDims;
1871
+ }
1872
+
1861
1873
namespace {
1862
1874
// Folds tensor.pad when padding is static zeros and the attribute
1863
1875
// doesn't request otherwise.
@@ -1940,13 +1952,169 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
1940
1952
return success ();
1941
1953
}
1942
1954
};
1955
+
1956
+ // / Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
1957
+ // / different dimensions. The pattern applies if the following preconditions
1958
+ // / hold:
1959
+ // / 1) the tensor::ExtractSliceOps are not rank-reducing,
1960
+ // / 2) the tensor::ExtractSliceOps have only unit-strides,
1961
+ // / 3) the tensor::PadOps perform only high-padding,
1962
+ // / 4) the tensor::PadOps have the same constant padding value,
1963
+ // / 5) the tensor::PadOps do not have common padding dimensions,
1964
+ // / 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
1965
+ // / zero-offset for every dimension.
1966
+ // / 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
1967
+ // / padded source dimensions.
1968
+ // /
1969
+ // / Example:
1970
+ // /
1971
+ // / ```mlir
1972
+ // / %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
1973
+ // / : tensor<64x64xf32> to tensor<?x64xf32>
1974
+ // / %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
1975
+ // / } : tensor<?x64xf32> to tensor<8x64xf32>
1976
+ // / %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
1977
+ // / : tensor<8x64xf32> to tensor<8x?xf32>
1978
+ // / %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
1979
+ // / } : tensor<8x?xf32> to tensor<8x4xf32>
1980
+ // / ```
1981
+ // /
1982
+ // / folds into:
1983
+ // /
1984
+ // / ```mlir
1985
+ // / %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
1986
+ // / : tensor<64x64xf32> to tensor<?x?xf32>
1987
+ // / %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
1988
+ // / } : tensor<?x?xf32> to tensor<8x4xf32>
1989
+ // / ```
1990
+ struct FoldOrthogonalPaddings : public OpRewritePattern <PadOp> {
1991
+ using OpRewritePattern<PadOp>::OpRewritePattern;
1992
+
1993
+ LogicalResult matchAndRewrite (PadOp padOp,
1994
+ PatternRewriter &rewriter) const override {
1995
+ auto innerSliceOp = padOp.source ().getDefiningOp <ExtractSliceOp>();
1996
+ if (!innerSliceOp)
1997
+ return failure ();
1998
+ auto outerPadOp = innerSliceOp.source ().getDefiningOp <PadOp>();
1999
+ if (!outerPadOp || outerPadOp.nofold ())
2000
+ return failure ();
2001
+ auto outerSliceOp = outerPadOp.source ().getDefiningOp <ExtractSliceOp>();
2002
+ if (!outerSliceOp)
2003
+ return failure ();
2004
+
2005
+ // 1) Fail if the chain is rank-reducing.
2006
+ int64_t rank = padOp.getSourceType ().getRank ();
2007
+ if (outerSliceOp.getSourceType ().getRank () != rank) {
2008
+ return rewriter.notifyMatchFailure (padOp,
2009
+ " cannot fold rank-reducing chain" );
2010
+ }
2011
+
2012
+ // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
2013
+ if (!innerSliceOp.hasUnitStride () || !outerSliceOp.hasUnitStride ()) {
2014
+ return rewriter.notifyMatchFailure (
2015
+ padOp, " cannot fold non-unit stride ExtractSliceOps" );
2016
+ }
2017
+
2018
+ // 3) Fail if the tensor::PadOps have non-zero low padding.
2019
+ if (!padOp.hasZeroLowPad () || !outerPadOp.hasZeroLowPad ()) {
2020
+ return rewriter.notifyMatchFailure (padOp,
2021
+ " cannot fold PadOps with low padding" );
2022
+ }
2023
+
2024
+ // 4) Fail if the tensor::PadOps padding values do not match.
2025
+ Attribute innerAttr, outerAttr;
2026
+ Value innerValue = padOp.getConstantPaddingValue ();
2027
+ Value outerValue = outerPadOp.getConstantPaddingValue ();
2028
+ if (!innerValue || !outerValue ||
2029
+ !matchPattern (innerValue, m_Constant (&innerAttr)) ||
2030
+ !matchPattern (outerValue, m_Constant (&outerAttr)) ||
2031
+ innerAttr != outerAttr) {
2032
+ return rewriter.notifyMatchFailure (
2033
+ padOp, " cannot fold PadOps with different padding values" );
2034
+ }
2035
+
2036
+ // 5) Fail if a dimension is padded by both tensor::PadOps.
2037
+ llvm::SmallBitVector innerDims = padOp.getPaddedDims ();
2038
+ llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims ();
2039
+ if (innerDims.anyCommon (outerDims)) {
2040
+ return rewriter.notifyMatchFailure (
2041
+ padOp, " cannot fold PadOps with common padding dimensions" );
2042
+ }
2043
+
2044
+ // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
2045
+ // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2046
+ // for every dimension, and use the offset the other pair. Fail if no
2047
+ // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
2048
+ // exists.
2049
+ SmallVector<OpFoldResult> newOffsets (rank, rewriter.getIndexAttr (0 ));
2050
+ for (auto &en : enumerate(newOffsets)) {
2051
+ OpFoldResult innerOffset = innerSliceOp.getMixedOffsets ()[en.index ()];
2052
+ OpFoldResult outerOffset = outerSliceOp.getMixedOffsets ()[en.index ()];
2053
+ if (!innerDims.test (en.index ()) &&
2054
+ (getConstantIntValue (innerOffset) == static_cast <int64_t >(0 ))) {
2055
+ en.value () = outerOffset;
2056
+ continue ;
2057
+ }
2058
+ if (!outerDims.test (en.index ()) &&
2059
+ (getConstantIntValue (outerOffset) == static_cast <int64_t >(0 ))) {
2060
+ en.value () = innerOffset;
2061
+ continue ;
2062
+ }
2063
+ return rewriter.notifyMatchFailure (
2064
+ padOp, " cannot find zero-offset and zero-padding pair" );
2065
+ }
2066
+
2067
+ // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size of
2068
+ // the outer tensor::ExtractSliceOp for the dimensions padded by the outer
2069
+ // tensor::PadOp and fail if the size of the inner tensor::ExtractSliceOp
2070
+ // does not match the size of the padded dimension. Otherwise, take the size
2071
+ // of the inner tensor::ExtractSliceOp.
2072
+ SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes ();
2073
+ for (auto &en : enumerate(newSizes)) {
2074
+ if (!outerDims.test (en.index ()))
2075
+ continue ;
2076
+ OpFoldResult sliceSize = innerSliceOp.getMixedSizes ()[en.index ()];
2077
+ int64_t sourceSize = innerSliceOp.getSourceType ().getShape ()[en.index ()];
2078
+ assert (!ShapedType::isDynamic (sourceSize) &&
2079
+ " expected padded dimension to have a static size" );
2080
+ if (getConstantIntValue (sliceSize) != sourceSize) {
2081
+ return rewriter.notifyMatchFailure (
2082
+ padOp, " cannot fold since the inner ExtractSliceOp size does not "
2083
+ " match the size of the outer padding" );
2084
+ }
2085
+ en.value () = outerSliceOp.getMixedSizes ()[en.index ()];
2086
+ }
2087
+
2088
+ // Combine the high paddings of the two tensor::PadOps.
2089
+ SmallVector<OpFoldResult> newHighPad (rank, rewriter.getIndexAttr (0 ));
2090
+ for (auto &en : enumerate(newHighPad)) {
2091
+ if (innerDims.test (en.index ()))
2092
+ newHighPad[en.index ()] = padOp.getMixedHighPad ()[en.index ()];
2093
+ if (outerDims.test (en.index ()))
2094
+ newHighPad[en.index ()] = outerPadOp.getMixedHighPad ()[en.index ()];
2095
+ }
2096
+
2097
+ // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs the
2098
+ // two paddings in one step.
2099
+ auto newSliceOp = rewriter.create <ExtractSliceOp>(
2100
+ padOp.getLoc (), outerSliceOp.source (), newOffsets, newSizes,
2101
+ innerSliceOp.getMixedStrides ());
2102
+ auto newPadOp = rewriter.create <PadOp>(
2103
+ padOp.getLoc (), padOp.getResultType (), newSliceOp.getResult (),
2104
+ padOp.getMixedLowPad (), newHighPad, padOp.nofold ());
2105
+ rewriter.inlineRegionBefore (padOp.getRegion (), newPadOp.getRegion (),
2106
+ newPadOp.getRegion ().begin ());
2107
+ rewriter.replaceOp (padOp, newPadOp.getResult ());
2108
+ return success ();
2109
+ }
2110
+ };
2111
+
1943
2112
} // namespace
1944
2113
1945
2114
void PadOp::getCanonicalizationPatterns (RewritePatternSet &results,
1946
2115
MLIRContext *context) {
1947
- results
1948
- .add <FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast>(
1949
- context);
2116
+ results.add <FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
2117
+ FoldOrthogonalPaddings>(context);
1950
2118
}
1951
2119
1952
2120
/// Return the padding value of the PadOp if it is constant. In this context,
0 commit comments