@@ -1986,12 +1986,26 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
1986
1986
return fromElementsOp.getElements ()[flatIndex];
1987
1987
}
1988
1988
1989
- OpFoldResult ExtractOp::fold (FoldAdaptor) {
1989
+ // / Fold an insert or extract operation into an poison value when a poison index
1990
+ // / is found at any dimension of the static position.
1991
+ static ub::PoisonAttr
1992
+ foldPoisonIndexInsertExtractOp (MLIRContext *context,
1993
+ ArrayRef<int64_t > staticPos, int64_t poisonVal) {
1994
+ if (!llvm::is_contained (staticPos, poisonVal))
1995
+ return ub::PoisonAttr ();
1996
+
1997
+ return ub::PoisonAttr::get (context);
1998
+ }
1999
+
2000
+ OpFoldResult ExtractOp::fold (FoldAdaptor adaptor) {
1990
2001
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
1991
2002
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
1992
2003
// mismatch).
1993
2004
if (getNumIndices () == 0 && getVector ().getType () == getResult ().getType ())
1994
2005
return getVector ();
2006
+ if (auto res = foldPoisonIndexInsertExtractOp (
2007
+ getContext (), adaptor.getStaticPosition (), kPoisonIndex ))
2008
+ return res;
1995
2009
if (succeeded (foldExtractOpFromExtractChain (*this )))
1996
2010
return getResult ();
1997
2011
if (auto res = ExtractFromInsertTransposeChainState (*this ).fold ())
@@ -2262,13 +2276,15 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2262
2276
// / Fold an insert or extract operation into an poison value when a poison index
2263
2277
// / is found at any dimension of the static position.
2264
2278
template <typename OpTy>
2265
- LogicalResult foldPoisonIndexInsertExtractOp (OpTy op,
2266
- PatternRewriter &rewriter) {
2267
- if (!llvm::is_contained (op.getStaticPosition (), OpTy::kPoisonIndex ))
2268
- return failure ();
2279
+ LogicalResult
2280
+ canonicalizePoisonIndexInsertExtractOp (OpTy op, PatternRewriter &rewriter) {
2281
+ if (auto poisonAttr = foldPoisonIndexInsertExtractOp (
2282
+ op.getContext (), op.getStaticPosition (), OpTy::kPoisonIndex )) {
2283
+ rewriter.replaceOpWithNewOp <ub::PoisonOp>(op, op.getType (), poisonAttr);
2284
+ return success ();
2285
+ }
2269
2286
2270
- rewriter.replaceOpWithNewOp <ub::PoisonOp>(op, op.getResult ().getType ());
2271
- return success ();
2287
+ return failure ();
2272
2288
}
2273
2289
2274
2290
} // namespace
@@ -2279,7 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2279
2295
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2280
2296
results.add (foldExtractFromShapeCastToShapeCast);
2281
2297
results.add (foldExtractFromFromElements);
2282
- results.add (foldPoisonIndexInsertExtractOp <ExtractOp>);
2298
+ results.add (canonicalizePoisonIndexInsertExtractOp <ExtractOp>);
2283
2299
}
2284
2300
2285
2301
static void populateFromInt64AttrArray (ArrayAttr arrayAttr,
@@ -3044,7 +3060,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3044
3060
MLIRContext *context) {
3045
3061
results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3046
3062
InsertOpConstantFolder>(context);
3047
- results.add (foldPoisonIndexInsertExtractOp <InsertOp>);
3063
+ results.add (canonicalizePoisonIndexInsertExtractOp <InsertOp>);
3048
3064
}
3049
3065
3050
3066
OpFoldResult vector::InsertOp::fold (FoldAdaptor adaptor) {
@@ -3053,6 +3069,10 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3053
3069
// (type mismatch).
3054
3070
if (getNumIndices () == 0 && getSourceType () == getType ())
3055
3071
return getSource ();
3072
+ if (auto res = foldPoisonIndexInsertExtractOp (
3073
+ getContext (), adaptor.getStaticPosition (), kPoisonIndex ))
3074
+ return res;
3075
+
3056
3076
return {};
3057
3077
}
3058
3078
0 commit comments