26
26
#include " mlir/IR/AffineMap.h"
27
27
#include " mlir/IR/Builders.h"
28
28
#include " mlir/IR/BuiltinAttributes.h"
29
- #include " mlir/IR/BuiltinOps.h"
30
29
#include " mlir/IR/BuiltinTypes.h"
31
30
#include " mlir/IR/DialectImplementation.h"
32
31
#include " mlir/IR/IRMapping.h"
42
41
#include " llvm/ADT/SmallVector.h"
43
42
#include " llvm/ADT/StringSet.h"
44
43
#include " llvm/ADT/TypeSwitch.h"
45
- #include " llvm/ADT/bit.h"
46
44
47
45
#include < cassert>
48
46
#include < cstdint>
@@ -2684,25 +2682,45 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
2684
2682
if (!v1Attr || !v2Attr)
2685
2683
return {};
2686
2684
2685
+ // Fold shuffle poison, poison -> poison.
2686
+ bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
2687
+ bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
2688
+ if (isV1Poison && isV2Poison)
2689
+ return ub::PoisonAttr::get (getContext ());
2690
+
2687
2691
// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2688
2692
// manipulation.
2689
2693
if (v1Type.getRank () != 1 )
2690
2694
return {};
2691
2695
2692
- int64_t v1Size = v1Type.getDimSize (0 );
2696
+ // Poison input attributes need special handling as they are not
2697
+ // DenseElementsAttr. If an index is poison, we select the first element of
2698
+ // the first non-poison input.
2699
+ SmallVector<Attribute> v1Elements, v2Elements;
2700
+ Attribute poisonElement;
2701
+ if (!isV2Poison) {
2702
+ v2Elements =
2703
+ to_vector (cast<DenseElementsAttr>(v2Attr).getValues <Attribute>());
2704
+ poisonElement = v2Elements[0 ];
2705
+ }
2706
+ if (!isV1Poison) {
2707
+ v1Elements =
2708
+ to_vector (cast<DenseElementsAttr>(v1Attr).getValues <Attribute>());
2709
+ poisonElement = v1Elements[0 ];
2710
+ }
2693
2711
2694
2712
SmallVector<Attribute> results;
2695
- auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues <Attribute>();
2696
- auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues <Attribute>();
2713
+ int64_t v1Size = v1Type.getDimSize (0 );
2697
2714
for (int64_t maskIdx : mask) {
2698
2715
Attribute indexedElm;
2699
- // Select v1[0] for poison indices.
2700
2716
// TODO: Return a partial poison vector when supported by the UB dialect.
2701
2717
if (maskIdx == ShuffleOp::kPoisonIndex ) {
2702
- indexedElm = v1Elements[ 0 ] ;
2718
+ indexedElm = poisonElement ;
2703
2719
} else {
2704
- indexedElm =
2705
- maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
2720
+ if (maskIdx < v1Size)
2721
+ indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
2722
+ else
2723
+ indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
2706
2724
}
2707
2725
2708
2726
results.push_back (indexedElm);
@@ -3319,13 +3337,15 @@ class InsertStridedSliceConstantFolder final
3319
3337
!destVector.hasOneUse ())
3320
3338
return failure ();
3321
3339
3322
- auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3323
-
3324
3340
TypedValue<VectorType> sourceValue = op.getSource ();
3325
3341
Attribute sourceCst;
3326
3342
if (!matchPattern (sourceValue, m_Constant (&sourceCst)))
3327
3343
return failure ();
3328
3344
3345
+ // TODO: Support poison.
3346
+ if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3347
+ return failure ();
3348
+
3329
3349
// TODO: Handle non-unit strides when they become available.
3330
3350
if (op.hasNonUnitStrides ())
3331
3351
return failure ();
@@ -3342,6 +3362,7 @@ class InsertStridedSliceConstantFolder final
3342
3362
// increasing linearized position indices.
3343
3363
// Because the destination may have higher dimensionality then the slice,
3344
3364
// we keep track of two overlapping sets of positions and offsets.
3365
+ auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3345
3366
auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3346
3367
auto sliceValuesIt = denseSlice.value_begin <Attribute>();
3347
3368
auto newValues = llvm::to_vector (denseDest.getValues <Attribute>());
0 commit comments