@@ -2673,43 +2673,51 @@ static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
2673
2673
}
2674
2674
2675
2675
OpFoldResult vector::ShuffleOp::fold (FoldAdaptor adaptor) {
2676
- VectorType v1Type = getV1VectorType ();
2676
+ auto v1Type = getV1VectorType ();
2677
+ auto v2Type = getV2VectorType ();
2678
+
2679
+ assert (!v1Type.isScalable () && !v2Type.isScalable () &&
2680
+ " Vector shuffle does not support scalable vectors" );
2681
+
2677
2682
// For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
2678
2683
// but must be a canonicalization into a vector.broadcast.
2679
2684
if (v1Type.getRank () == 0 )
2680
2685
return {};
2681
2686
2682
- // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
2683
- if (!v1Type. isScalable () &&
2684
- isStepIndexArray (getMask () , 0 , v1Type.getDimSize (0 )))
2687
+ // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
2688
+ auto mask = getMask ();
2689
+ if ( isStepIndexArray (mask , 0 , v1Type.getDimSize (0 )))
2685
2690
return getV1 ();
2686
- // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
2687
- if (!getV1VectorType ().isScalable () && !getV2VectorType ().isScalable () &&
2688
- isStepIndexArray (getMask (), getV1VectorType ().getDimSize (0 ),
2689
- getV2VectorType ().getDimSize (0 )))
2691
+ // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
2692
+ if (isStepIndexArray (mask, v1Type.getDimSize (0 ), v2Type.getDimSize (0 )))
2690
2693
return getV2 ();
2691
2694
2692
- Attribute lhs = adaptor.getV1 (), rhs = adaptor.getV2 ();
2693
- if (!lhs || !rhs )
2695
+ Attribute v1Attr = adaptor.getV1 (), v2Attr = adaptor.getV2 ();
2696
+ if (!v1Attr || !v2Attr )
2694
2697
return {};
2695
2698
2696
- auto lhsType =
2697
- llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType ());
2698
2699
// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2699
2700
// manipulation.
2700
- if (lhsType .getRank () != 1 )
2701
+ if (v1Type .getRank () != 1 )
2701
2702
return {};
2702
- int64_t lhsSize = lhsType.getDimSize (0 );
2703
+
2704
+ int64_t v1Size = v1Type.getDimSize (0 );
2703
2705
2704
2706
SmallVector<Attribute> results;
2705
- auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues <Attribute>();
2706
- auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues <Attribute>();
2707
- for (int64_t i : this ->getMask ()) {
2708
- if (i >= lhsSize) {
2709
- results.push_back (rhsElements[i - lhsSize]);
2707
+ auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues <Attribute>();
2708
+ auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues <Attribute>();
2709
+ for (int64_t maskIdx : mask) {
2710
+ Attribute indexedElm;
2711
+ // Select v1[0] for poison indices.
2712
+ // TODO: Return a partial poison vector when supported by the UB dialect.
2713
+ if (maskIdx == ShuffleOp::kPoisonIndex ) {
2714
+ indexedElm = v1Elements[0 ];
2710
2715
} else {
2711
- results.push_back (lhsElements[i]);
2716
+ indexedElm =
2717
+ maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
2712
2718
}
2719
+
2720
+ results.push_back (indexedElm);
2713
2721
}
2714
2722
2715
2723
return DenseElementsAttr::get (getResultVectorType (), results);
0 commit comments