@@ -2780,28 +2780,39 @@ static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals,
2780
2780
// Scalable vectors cannot use arbitrary shufflevectors (only splats), so
2781
2781
// must use intrinsics to interleave.
2782
2782
if (VecTy->isScalableTy ()) {
2783
+ if (Vals.size () == 2 ) {
2784
+ VectorType *WideVecTy = VectorType::getDoubleElementsVectorType (VecTy);
2785
+ return Builder.CreateIntrinsic (WideVecTy, Intrinsic::vector_interleave2,
2786
+ Vals,
2787
+ /* FMFSource=*/ nullptr , Name);
2788
+ }
2783
2789
unsigned InterleaveFactor = Vals.size ();
2784
- SmallVector<Value *> InterleavingValues;
2785
- unsigned InterleavingValuesCount =
2786
- InterleaveFactor + (InterleaveFactor - 2 );
2787
- InterleavingValues.resize (InterleaveFactor);
2788
- // Place the values to be interleaved in the correct order for the
2789
- // interleaving
2790
- for (unsigned I = 0 , J = InterleaveFactor / 2 , K = 0 ; K < InterleaveFactor;
2791
- K++) {
2792
- if (K % 2 == 0 ) {
2793
- InterleavingValues[K] = Vals[I];
2794
- I++;
2795
- } else {
2796
- InterleavingValues[K] = Vals[J];
2797
- J++;
2790
+ SmallVector<Value *> InterleavingValues (Vals);
2791
+ // The total number of nodes in a balanced binary tree is calculated as 2n -
2792
+ // 1, where `n` is the number of leaf nodes (`InterleaveFactor`). In this
2793
+ // context, we exclude the root node because it will serve as the final
2794
+ // interleaved value. Thus, the number of nodes to be processed/interleaved
2795
+ // is: (2n - 1) - 1 = 2n - 2.
2796
+
2797
+ unsigned NumInterleavingValues = 2 * InterleaveFactor - 2 ;
2798
+ for (unsigned I = 1 ; I < NumInterleavingValues; I += 2 ) {
2799
+ // Number of values that haven't been processed yet:
2800
+ unsigned Remaining = InterleavingValues.size () - I + 1 ;
2801
+ if (Remaining > 2 && isPowerOf2_32 (Remaining)) {
2802
+
2803
+ // The remaining values form a new level in the interleaving tree.
2804
+ // Arrange these values in the correct interleaving order for this
2805
+ // level. The interleaving order places alternating elements from the
2806
+ // first and second halves.
2807
+ std::vector<Value *> RemainingValues (InterleavingValues.begin () + I - 1 ,
2808
+ InterleavingValues.end ());
2809
+ unsigned Middle = Remaining / 2 ;
2810
+ for (unsigned J = I - 1 , K = 0 ; J < InterleavingValues.size ();
2811
+ J += 2 , K++) {
2812
+ InterleavingValues[J] = RemainingValues[K];
2813
+ InterleavingValues[J + 1 ] = RemainingValues[Middle + K];
2814
+ }
2798
2815
}
2799
- }
2800
- #ifndef NDEBUG
2801
- for (Value *Val : InterleavingValues)
2802
- assert (Val && " NULL Interleaving Value" );
2803
- #endif
2804
- for (unsigned I = 1 ; I < InterleavingValuesCount; I += 2 ) {
2805
2816
VectorType *InterleaveTy =
2806
2817
cast<VectorType>(InterleavingValues[I]->getType ());
2807
2818
VectorType *WideVecTy =
@@ -2812,7 +2823,7 @@ static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals,
2812
2823
/* FMFSource=*/ nullptr , Name);
2813
2824
InterleavingValues.push_back (InterleaveRes);
2814
2825
}
2815
- return InterleavingValues[InterleavingValuesCount ];
2826
+ return InterleavingValues[NumInterleavingValues ];
2816
2827
}
2817
2828
2818
2829
// Fixed length. Start by concatenating all vectors into a wide vector.
@@ -2951,42 +2962,48 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
2951
2962
2952
2963
SmallVector<Value *> DeinterleavedValues;
2953
2964
// If the InterleaveFactor is > 2, we will have to do recursive
2954
- // deinterleaving, because the current available deinterleave intrinsice
2965
+ // deinterleaving, because the current available deinterleave intrinsic
2955
2966
// supports only Factor of 2. DeinterleaveCount represents how many times
2956
2967
// we will do deinterleaving, we will do deinterleave on all nonleaf
2957
2968
// nodes in the deinterleave tree.
2958
2969
unsigned DeinterleaveCount = InterleaveFactor - 1 ;
2959
- std::queue <Value *> TempDeinterleavedValues;
2960
- TempDeinterleavedValues.push (NewLoad);
2970
+ std::vector <Value *> TempDeinterleavedValues;
2971
+ TempDeinterleavedValues.push_back (NewLoad);
2961
2972
for (unsigned I = 0 ; I < DeinterleaveCount; ++I) {
2962
- Value *ValueToDeinterleave = TempDeinterleavedValues.front ();
2963
- auto *DiTy = ValueToDeinterleave->getType ();
2964
- TempDeinterleavedValues.pop ();
2973
+ auto *DiTy = TempDeinterleavedValues[I]->getType ();
2965
2974
Value *DI = State.Builder .CreateIntrinsic (
2966
- Intrinsic::vector_deinterleave2, DiTy, ValueToDeinterleave ,
2975
+ Intrinsic::vector_deinterleave2, DiTy, TempDeinterleavedValues[I] ,
2967
2976
/* FMFSource=*/ nullptr , " strided.vec" );
2968
2977
Value *StridedVec = State.Builder .CreateExtractValue (DI, 0 );
2969
- TempDeinterleavedValues.push (StridedVec);
2978
+ TempDeinterleavedValues.push_back (StridedVec);
2970
2979
StridedVec = State.Builder .CreateExtractValue (DI, 1 );
2971
- TempDeinterleavedValues.push (StridedVec);
2972
- }
2973
-
2974
- assert (TempDeinterleavedValues.size () == InterleaveFactor &&
2975
- " Num of deinterleaved values must equals to InterleaveFactor" );
2976
- // Sort deinterleaved values
2977
- DeinterleavedValues.resize (InterleaveFactor);
2978
- for (unsigned I = 0 , J = InterleaveFactor / 2 , K = 0 ;
2979
- K < InterleaveFactor; K++) {
2980
- auto *DeinterleavedValue = TempDeinterleavedValues.front ();
2981
- TempDeinterleavedValues.pop ();
2982
- if (K % 2 == 0 ) {
2983
- DeinterleavedValues[I] = DeinterleavedValue;
2984
- I++;
2985
- } else {
2986
- DeinterleavedValues[J] = DeinterleavedValue;
2987
- J++;
2980
+ TempDeinterleavedValues.push_back (StridedVec);
2981
+ // Perform sorting at the start of each new level in the tree.
2982
+ // A new level begins when the number of remaining values is a power
2983
+ // of 2 and greater than 2. If a level has only 2 nodes, no sorting is
2984
+ // needed as they are already in order. Number of remaining values to
2985
+ // be processed:
2986
+ unsigned NumRemainingValues = TempDeinterleavedValues.size () - I - 1 ;
2987
+ if (NumRemainingValues > 2 && isPowerOf2_32 (NumRemainingValues)) {
2988
+ // These remaining values represent a new level in the tree.
2989
+ // Reorder the values to match the correct deinterleaving order.
2990
+ std::vector<Value *> RemainingValues (
2991
+ TempDeinterleavedValues.begin () + I + 1 ,
2992
+ TempDeinterleavedValues.end ());
2993
+ unsigned Middle = NumRemainingValues / 2 ;
2994
+ for (unsigned J = 0 , K = I + 1 ; J < NumRemainingValues;
2995
+ J += 2 , K++) {
2996
+ TempDeinterleavedValues[K] = RemainingValues[J];
2997
+ TempDeinterleavedValues[Middle + K] = RemainingValues[J + 1 ];
2998
+ }
2988
2999
}
2989
3000
}
3001
+ // Final deinterleaved values:
3002
+ DeinterleavedValues.insert (DeinterleavedValues.begin (),
3003
+ TempDeinterleavedValues.begin () +
3004
+ InterleaveFactor - 1 ,
3005
+ TempDeinterleavedValues.end ());
3006
+
2990
3007
#ifndef NDEBUG
2991
3008
for (Value *Val : DeinterleavedValues)
2992
3009
assert (Val && " NULL Deinterleaved Value" );
0 commit comments