@@ -3013,94 +3013,78 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3013
3013
}
3014
3014
};
3015
3015
3016
- // Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
3017
- class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
3018
- public:
3019
- using OpRewritePattern::OpRewritePattern;
3020
-
3021
- // Do not create constants with more than `vectorSizeFoldThreashold` elements,
3022
- // unless the source vector constant has a single use.
3023
- static constexpr int64_t vectorSizeFoldThreshold = 256 ;
3024
-
3025
- LogicalResult matchAndRewrite (InsertOp op,
3026
- PatternRewriter &rewriter) const override {
3027
- // TODO: Canonicalization for dynamic position not implemented yet.
3028
- if (op.hasDynamicPosition ())
3029
- return failure ();
3016
+ } // namespace
3030
3017
3031
- // Return if 'InsertOp' operand is not defined by a compatible vector
3032
- // ConstantOp.
3033
- TypedValue<VectorType> destVector = op.getDest ();
3034
- Attribute vectorDestCst;
3035
- if (!matchPattern (destVector, m_Constant (&vectorDestCst)))
3036
- return failure ();
3037
- auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
3038
- if (!denseDest)
3039
- return failure ();
3018
+ static Attribute
3019
+ foldDenseElementsAttrDestInsertOp (InsertOp insertOp, Attribute srcAttr,
3020
+ Attribute dstAttr,
3021
+ int64_t maxVectorSizeFoldThreshold) {
3022
+ if (insertOp.hasDynamicPosition ())
3023
+ return {};
3040
3024
3041
- VectorType destTy = destVector. getType ( );
3042
- if (destTy. isScalable () )
3043
- return failure () ;
3025
+ auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr );
3026
+ if (!denseDst )
3027
+ return {} ;
3044
3028
3045
- // Make sure we do not create too many large constants.
3046
- if (destTy.getNumElements () > vectorSizeFoldThreshold &&
3047
- !destVector.hasOneUse ())
3048
- return failure ();
3029
+ if (!srcAttr) {
3030
+ return {};
3031
+ }
3049
3032
3050
- Value sourceValue = op.getSource ();
3051
- Attribute sourceCst;
3052
- if (!matchPattern (sourceValue, m_Constant (&sourceCst)))
3053
- return failure ();
3033
+ VectorType destTy = insertOp.getDestVectorType ();
3034
+ if (destTy.isScalable ())
3035
+ return {};
3054
3036
3055
- // Calculate the linearized position of the continuous chunk of elements to
3056
- // insert.
3057
- llvm::SmallVector<int64_t > completePositions (destTy.getRank (), 0 );
3058
- copy (op.getStaticPosition (), completePositions.begin ());
3059
- int64_t insertBeginPosition =
3060
- linearize (completePositions, computeStrides (destTy.getShape ()));
3061
-
3062
- SmallVector<Attribute> insertedValues;
3063
- Type destEltType = destTy.getElementType ();
3064
-
3065
- // The `convertIntegerAttr` method specifically handles the case
3066
- // for `llvm.mlir.constant` which can hold an attribute with a
3067
- // different type than the return type.
3068
- if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
3069
- for (auto value : denseSource.getValues <Attribute>())
3070
- insertedValues.push_back (convertIntegerAttr (value, destEltType));
3071
- } else {
3072
- insertedValues.push_back (convertIntegerAttr (sourceCst, destEltType));
3073
- }
3037
+ // Make sure we do not create too many large constants.
3038
+ if (destTy.getNumElements () > maxVectorSizeFoldThreshold &&
3039
+ !insertOp->hasOneUse ())
3040
+ return {};
3074
3041
3075
- auto allValues = llvm::to_vector (denseDest.getValues <Attribute>());
3076
- copy (insertedValues, allValues.begin () + insertBeginPosition);
3077
- auto newAttr = DenseElementsAttr::get (destTy, allValues);
3042
+ // Calculate the linearized position of the continuous chunk of elements to
3043
+ // insert.
3044
+ llvm::SmallVector<int64_t > completePositions (destTy.getRank (), 0 );
3045
+ copy (insertOp.getStaticPosition (), completePositions.begin ());
3046
+ int64_t insertBeginPosition =
3047
+ linearize (completePositions, computeStrides (destTy.getShape ()));
3078
3048
3079
- rewriter.replaceOpWithNewOp <arith::ConstantOp>(op, newAttr);
3080
- return success ();
3081
- }
3049
+ SmallVector<Attribute> insertedValues;
3050
+ Type destEltType = destTy.getElementType ();
3082
3051
3083
- private:
3084
3052
/// Re-creates an IntegerAttr with the expected type if there's
3085
3053
/// a mismatch.
3086
- Attribute convertIntegerAttr (Attribute attr, Type expectedType) const {
3054
+ auto convertIntegerAttr = [] (Attribute attr, Type expectedType) -> Attribute {
3087
3055
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3088
3056
if (intAttr.getType () != expectedType)
3089
3057
return IntegerAttr::get (expectedType, intAttr.getInt ());
3090
3058
}
3091
3059
return attr;
3060
+ };
3061
+
3062
+ // The `convertIntegerAttr` lambda specifically handles the case
3063
+ // for `llvm.mlir.constant` which can hold an attribute with a
3064
+ // different type than the return type.
3065
+ if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3066
+ for (auto value : denseSource.getValues <Attribute>())
3067
+ insertedValues.push_back (convertIntegerAttr (value, destEltType));
3068
+ } else {
3069
+ insertedValues.push_back (convertIntegerAttr (srcAttr, destEltType));
3092
3070
}
3093
- };
3094
3071
3095
- } // namespace
3072
+ auto allValues = llvm::to_vector (denseDst.getValues <Attribute>());
3073
+ copy (insertedValues, allValues.begin () + insertBeginPosition);
3074
+ auto newAttr = DenseElementsAttr::get (destTy, allValues);
3075
+
3076
+ return newAttr;
3077
+ }
3096
3078
3097
3079
void InsertOp::getCanonicalizationPatterns (RewritePatternSet &results,
3098
3080
MLIRContext *context) {
3099
- results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3100
- InsertOpConstantFolder>(context);
3081
+ results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
3101
3082
}
3102
3083
3103
3084
OpFoldResult vector::InsertOp::fold (FoldAdaptor adaptor) {
3085
+ // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3086
+ // unless the insert op itself has a single use (`insertOp->hasOneUse()`).
3087
+ constexpr int64_t vectorSizeFoldThreshold = 256 ;
3104
3088
// Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
3105
3089
// %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3106
3090
// (type mismatch).
@@ -3112,6 +3096,11 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3112
3096
if (auto res = foldPoisonIndexInsertExtractOp (
3113
3097
getContext (), adaptor.getStaticPosition (), kPoisonIndex ))
3114
3098
return res;
3099
+ if (auto res = foldDenseElementsAttrDestInsertOp (*this , adaptor.getSource (),
3100
+ adaptor.getDest (),
3101
+ vectorSizeFoldThreshold)) {
3102
+ return res;
3103
+ }
3115
3104
3116
3105
return {};
3117
3106
}
0 commit comments