|
35 | 35 | #include "llvm/Transforms/Utils/LoopUtils.h"
|
36 | 36 | #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
|
37 | 37 | #include <cassert>
|
38 |
| -#include <queue> |
39 | 38 |
|
40 | 39 | using namespace llvm;
|
41 | 40 |
|
@@ -2780,50 +2779,22 @@ static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals,
|
2780 | 2779 | // Scalable vectors cannot use arbitrary shufflevectors (only splats), so
|
2781 | 2780 | // must use intrinsics to interleave.
|
2782 | 2781 | if (VecTy->isScalableTy()) {
|
2783 |
| - if (Vals.size() == 2) { |
2784 |
| - VectorType *WideVecTy = VectorType::getDoubleElementsVectorType(VecTy); |
2785 |
| - return Builder.CreateIntrinsic(WideVecTy, Intrinsic::vector_interleave2, |
2786 |
| - Vals, |
2787 |
| - /*FMFSource=*/nullptr, Name); |
2788 |
| - } |
2789 | 2782 | unsigned InterleaveFactor = Vals.size();
|
2790 | 2783 | SmallVector<Value *> InterleavingValues(Vals);
|
2791 |
| - // The total number of nodes in a balanced binary tree is calculated as 2n - |
2792 |
| - // 1, where `n` is the number of leaf nodes (`InterleaveFactor`). In this |
2793 |
| - // context, we exclude the root node because it will serve as the final |
2794 |
| - // interleaved value. Thus, the number of nodes to be processed/interleaved |
2795 |
| - // is: (2n - 1) - 1 = 2n - 2. |
2796 |
| - |
2797 |
| - unsigned NumInterleavingValues = 2 * InterleaveFactor - 2; |
2798 |
| - for (unsigned I = 1; I < NumInterleavingValues; I += 2) { |
2799 |
| - // values that haven't been processed yet: |
2800 |
| - unsigned Remaining = InterleavingValues.size() - I + 1; |
2801 |
| - if (Remaining > 2 && isPowerOf2_32(Remaining)) { |
2802 |
| - |
2803 |
| - // The remaining values form a new level in the interleaving tree. |
2804 |
| - // Arrange these values in the correct interleaving order for this |
2805 |
| - // level. The interleaving order places alternating elements from the |
2806 |
| - // first and second halves, |
2807 |
| - std::vector<Value *> RemainingValues(InterleavingValues.begin() + I - 1, |
2808 |
| - InterleavingValues.end()); |
2809 |
| - unsigned Middle = Remaining / 2; |
2810 |
| - for (unsigned J = I - 1, K = 0; J < InterleavingValues.size(); |
2811 |
| - J += 2, K++) { |
2812 |
| - InterleavingValues[J] = RemainingValues[K]; |
2813 |
| - InterleavingValues[J + 1] = RemainingValues[Middle + K]; |
2814 |
| - } |
2815 |
| - } |
| 2784 | + // As we are interleaving, the number of values shrinks until we have the |
| 2785 | + // single final interleaved value. |
| 2786 | + for (unsigned Midpoint = Factor / 2; Midpoint > 0; Midpoint /= 2) { |
2816 | 2787 | VectorType *InterleaveTy =
|
2817 |
| - cast<VectorType>(InterleavingValues[I]->getType()); |
| 2788 | + cast<VectorType>(InterleavingValues[0]->getType()); |
2818 | 2789 | VectorType *WideVecTy =
|
2819 | 2790 | VectorType::getDoubleElementsVectorType(InterleaveTy);
|
2820 |
| - auto *InterleaveRes = Builder.CreateIntrinsic( |
2821 |
| - WideVecTy, Intrinsic::vector_interleave2, |
2822 |
| - {InterleavingValues[I - 1], InterleavingValues[I]}, |
2823 |
| - /*FMFSource=*/nullptr, Name); |
2824 |
| - InterleavingValues.push_back(InterleaveRes); |
| 2791 | + for (unsigned I = 0; I < Midpoint; ++I) |
| 2792 | + InterleavingValues[I] = Builder.CreateIntrinsic( |
| 2793 | + WideVecTy, Intrinsic::vector_interleave2, |
| 2794 | + {InterleavingValues[I], InterleavingValues[Midpoint + I]}, |
| 2795 | + /*FMFSource=*/nullptr, Name); |
2825 | 2796 | }
|
2826 |
| - return InterleavingValues[NumInterleavingValues]; |
| 2797 | + return InterleavingValues[0]; |
2827 | 2798 | }
|
2828 | 2799 |
|
2829 | 2800 | // Fixed length. Start by concatenating all vectors into a wide vector.
|
@@ -2960,49 +2931,31 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
|
2960 | 2931 | // Scalable vectors cannot use arbitrary shufflevectors (only splats),
|
2961 | 2932 | // so must use intrinsics to deinterleave.
|
2962 | 2933 |
|
2963 |
| - SmallVector<Value *> DeinterleavedValues; |
2964 |
| - // If the InterleaveFactor is > 2, so we will have to do recursive |
| 2934 | + SmallVector<Value *> DeinterleavedValues(InterleaveFactor); |
| 2935 | + DeinterleavedValues[0] = NewLoad; |
| 2936 | + // For the case of InterleaveFactor > 2, we will have to do recursive |
2965 | 2937 | // deinterleaving, because the current available deinterleave intrinsic
|
2966 |
| - // supports only Factor of 2. DeinterleaveCount represent how many times |
2967 |
| - // we will do deinterleaving, we will do deinterleave on all nonleaf |
2968 |
| - // nodes in the deinterleave tree. |
2969 |
| - unsigned DeinterleaveCount = InterleaveFactor - 1; |
2970 |
| - std::vector<Value *> TempDeinterleavedValues; |
2971 |
| - TempDeinterleavedValues.push_back(NewLoad); |
2972 |
| - for (unsigned I = 0; I < DeinterleaveCount; ++I) { |
2973 |
| - auto *DiTy = TempDeinterleavedValues[I]->getType(); |
2974 |
| - Value *DI = State.Builder.CreateIntrinsic( |
2975 |
| - Intrinsic::vector_deinterleave2, DiTy, TempDeinterleavedValues[I], |
2976 |
| - /*FMFSource=*/nullptr, "strided.vec"); |
2977 |
| - Value *StridedVec = State.Builder.CreateExtractValue(DI, 0); |
2978 |
| - TempDeinterleavedValues.push_back(StridedVec); |
2979 |
| - StridedVec = State.Builder.CreateExtractValue(DI, 1); |
2980 |
| - TempDeinterleavedValues.push_back(StridedVec); |
2981 |
| - // Perform sorting at the start of each new level in the tree. |
2982 |
| - // A new level begins when the number of remaining values is a power |
2983 |
| - // of 2 and greater than 2. If a level has only 2 nodes, no sorting is |
2984 |
| - // needed as they are already in order. Number of remaining values to |
2985 |
| - // be processed: |
2986 |
| - unsigned NumRemainingValues = TempDeinterleavedValues.size() - I - 1; |
2987 |
| - if (NumRemainingValues > 2 && isPowerOf2_32(NumRemainingValues)) { |
2988 |
| - // these remaining values represent a new level in the tree, |
2989 |
| - // Reorder the values to match the correct deinterleaving order. |
2990 |
| - std::vector<Value *> RemainingValues( |
2991 |
| - TempDeinterleavedValues.begin() + I + 1, |
2992 |
| - TempDeinterleavedValues.end()); |
2993 |
| - unsigned Middle = NumRemainingValues / 2; |
2994 |
| - for (unsigned J = 0, K = I + 1; J < NumRemainingValues; |
2995 |
| - J += 2, K++) { |
2996 |
| - TempDeinterleavedValues[K] = RemainingValues[J]; |
2997 |
| - TempDeinterleavedValues[Middle + K] = RemainingValues[J + 1]; |
2998 |
| - } |
| 2938 | + // supports only Factor of 2, otherwise it will bail out after the first |
| 2939 | + // iteration. |
| 2940 | + // As we are deinterleaving, the number of values doubles until it reaches |
| 2941 | + // the InterleaveFactor. |
| 2942 | + for (int NumVectors = 1; NumVectors < InterleaveFactor; |
| 2943 | + NumVectors *= 2) { |
| 2944 | + // deinterleave the elements within the vector |
| 2945 | + std::vector<Value *> TempDeinterleavedValues(NumVectors); |
| 2946 | + for (int I = 0; I < NumVectors; ++I) { |
| 2947 | + auto *DiTy = DeinterleavedValues[I]->getType(); |
| 2948 | + TempDeinterleavedValues[I] = State.Builder.CreateIntrinsic( |
| 2949 | + Intrinsic::vector_deinterleave2, DiTy, DeinterleavedValues[I], |
| 2950 | + /*FMFSource=*/nullptr, "strided.vec"); |
2999 | 2951 | }
|
| 2952 | + // Extract the deinterleaved values: |
| 2953 | + for (int I = 0; I < 2; ++I) |
| 2954 | + for (int J = 0; J < NumVectors; ++J) |
| 2955 | + DeinterleavedValues[NumVectors * I + J] = |
| 2956 | + State.Builder.CreateExtractValue(TempDeinterleavedValues[J], |
| 2957 | + I); |
3000 | 2958 | }
|
3001 |
| - // Final deinterleaved values: |
3002 |
| - DeinterleavedValues.insert(DeinterleavedValues.begin(), |
3003 |
| - TempDeinterleavedValues.begin() + |
3004 |
| - InterleaveFactor - 1, |
3005 |
| - TempDeinterleavedValues.end()); |
3006 | 2959 |
|
3007 | 2960 | #ifndef NDEBUG
|
3008 | 2961 | for (Value *Val : DeinterleavedValues)
|
|
0 commit comments