@@ -2031,20 +2031,71 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2031
2031
static Attribute foldPoisonIndexInsertExtractOp (MLIRContext *context,
2032
2032
ArrayRef<int64_t > staticPos,
2033
2033
int64_t poisonVal) {
2034
- if (!llvm:: is_contained (staticPos, poisonVal))
2034
+ if (!is_contained (staticPos, poisonVal))
2035
2035
return {};
2036
2036
2037
2037
return ub::PoisonAttr::get (context);
2038
2038
}
2039
2039
2040
2040
// / Fold a vector extract from is a poison source.
2041
2041
static Attribute foldPoisonSrcExtractOp (Attribute srcAttr) {
2042
- if (llvm:: isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2042
+ if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2043
2043
return srcAttr;
2044
2044
2045
2045
return {};
2046
2046
}
2047
2047
2048
+ // / Fold a vector extract extracting from a DenseElementsAttr.
2049
+ static Attribute foldDenseElementsAttrSrcExtractOp (ExtractOp extractOp,
2050
+ Attribute srcAttr) {
2051
+ auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2052
+ if (!denseAttr) {
2053
+ return {};
2054
+ }
2055
+
2056
+ if (denseAttr.isSplat ()) {
2057
+ Attribute newAttr = denseAttr.getSplatValue <Attribute>();
2058
+ if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType ()))
2059
+ newAttr = DenseElementsAttr::get (vecDstType, newAttr);
2060
+ return newAttr;
2061
+ }
2062
+
2063
+ auto vecTy = cast<VectorType>(extractOp.getSourceVectorType ());
2064
+ if (vecTy.isScalable ())
2065
+ return {};
2066
+
2067
+ if (extractOp.hasDynamicPosition ()) {
2068
+ return {};
2069
+ }
2070
+
2071
+ // Materializing subsets of a large constant array can generally lead to
2072
+ // explosion in IR size because of different combination of subsets that
2073
+ // can exist. However, vector.extract is a restricted form of subset
2074
+ // extract where you can only extract non-overlapping (or the same) subset for
2075
+ // a given rank of the subset. Because of this property, the IR size can only
2076
+ // increase at most by `rank * size(array)` from a single constant array being
2077
+ // extracted by multiple extracts.
2078
+
2079
+ // Calculate the linearized position of the continuous chunk of elements to
2080
+ // extract.
2081
+ SmallVector<int64_t > completePositions (vecTy.getRank (), 0 );
2082
+ copy (extractOp.getStaticPosition (), completePositions.begin ());
2083
+ int64_t startPos =
2084
+ linearize (completePositions, computeStrides (vecTy.getShape ()));
2085
+ auto denseValuesBegin = denseAttr.value_begin <TypedAttr>() + startPos;
2086
+
2087
+ TypedAttr newAttr;
2088
+ if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType ())) {
2089
+ SmallVector<Attribute> elementValues (
2090
+ denseValuesBegin, denseValuesBegin + resVecTy.getNumElements ());
2091
+ newAttr = DenseElementsAttr::get (resVecTy, elementValues);
2092
+ } else {
2093
+ newAttr = *denseValuesBegin;
2094
+ }
2095
+
2096
+ return newAttr;
2097
+ }
2098
+
2048
2099
OpFoldResult ExtractOp::fold (FoldAdaptor adaptor) {
2049
2100
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2050
2101
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2056,6 +2107,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2056
2107
return res;
2057
2108
if (auto res = foldPoisonSrcExtractOp (adaptor.getVector ()))
2058
2109
return res;
2110
+ if (auto res = foldDenseElementsAttrSrcExtractOp (*this , adaptor.getVector ()))
2111
+ return res;
2059
2112
if (succeeded (foldExtractOpFromExtractChain (*this )))
2060
2113
return getResult ();
2061
2114
if (auto res = ExtractFromInsertTransposeChainState (*this ).fold ())
@@ -2119,80 +2172,6 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2119
2172
}
2120
2173
};
2121
2174
2122
- // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
2123
- class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
2124
- public:
2125
- using OpRewritePattern::OpRewritePattern;
2126
-
2127
- LogicalResult matchAndRewrite (ExtractOp extractOp,
2128
- PatternRewriter &rewriter) const override {
2129
- // Return if 'ExtractOp' operand is not defined by a splat vector
2130
- // ConstantOp.
2131
- Value sourceVector = extractOp.getVector ();
2132
- Attribute vectorCst;
2133
- if (!matchPattern (sourceVector, m_Constant (&vectorCst)))
2134
- return failure ();
2135
- auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
2136
- if (!splat)
2137
- return failure ();
2138
- TypedAttr newAttr = splat.getSplatValue <TypedAttr>();
2139
- if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType ()))
2140
- newAttr = DenseElementsAttr::get (vecDstType, newAttr);
2141
- rewriter.replaceOpWithNewOp <arith::ConstantOp>(extractOp, newAttr);
2142
- return success ();
2143
- }
2144
- };
2145
-
2146
- // Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
2147
- class ExtractOpNonSplatConstantFolder final
2148
- : public OpRewritePattern<ExtractOp> {
2149
- public:
2150
- using OpRewritePattern::OpRewritePattern;
2151
-
2152
- LogicalResult matchAndRewrite (ExtractOp extractOp,
2153
- PatternRewriter &rewriter) const override {
2154
- // TODO: Canonicalization for dynamic position not implemented yet.
2155
- if (extractOp.hasDynamicPosition ())
2156
- return failure ();
2157
-
2158
- // Return if 'ExtractOp' operand is not defined by a compatible vector
2159
- // ConstantOp.
2160
- Value sourceVector = extractOp.getVector ();
2161
- Attribute vectorCst;
2162
- if (!matchPattern (sourceVector, m_Constant (&vectorCst)))
2163
- return failure ();
2164
-
2165
- auto vecTy = llvm::cast<VectorType>(sourceVector.getType ());
2166
- if (vecTy.isScalable ())
2167
- return failure ();
2168
-
2169
- // The splat case is handled by `ExtractOpSplatConstantFolder`.
2170
- auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
2171
- if (!dense || dense.isSplat ())
2172
- return failure ();
2173
-
2174
- // Calculate the linearized position of the continuous chunk of elements to
2175
- // extract.
2176
- llvm::SmallVector<int64_t > completePositions (vecTy.getRank (), 0 );
2177
- copy (extractOp.getStaticPosition (), completePositions.begin ());
2178
- int64_t elemBeginPosition =
2179
- linearize (completePositions, computeStrides (vecTy.getShape ()));
2180
- auto denseValuesBegin = dense.value_begin <TypedAttr>() + elemBeginPosition;
2181
-
2182
- TypedAttr newAttr;
2183
- if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType ())) {
2184
- SmallVector<Attribute> elementValues (
2185
- denseValuesBegin, denseValuesBegin + resVecTy.getNumElements ());
2186
- newAttr = DenseElementsAttr::get (resVecTy, elementValues);
2187
- } else {
2188
- newAttr = *denseValuesBegin;
2189
- }
2190
-
2191
- rewriter.replaceOpWithNewOp <arith::ConstantOp>(extractOp, newAttr);
2192
- return success ();
2193
- }
2194
- };
2195
-
2196
2175
// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
2197
2176
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2198
2177
public:
@@ -2330,8 +2309,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2330
2309
2331
2310
void ExtractOp::getCanonicalizationPatterns (RewritePatternSet &results,
2332
2311
MLIRContext *context) {
2333
- results.add <ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
2334
- ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2312
+ results.add <ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2335
2313
results.add (foldExtractFromShapeCastToShapeCast);
2336
2314
results.add (foldExtractFromFromElements);
2337
2315
}
0 commit comments