Skip to content

Commit 08f0cda

Browse files
committed
refactoring
Change-Id: If2a3789ed76c98a5f1d1be729f5051a2c54af2a7
1 parent d0f3ae0 commit 08f0cda

File tree

1 file changed

+32
-79
lines changed

1 file changed

+32
-79
lines changed

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 32 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
#include "llvm/Transforms/Utils/LoopUtils.h"
3636
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
3737
#include <cassert>
38-
#include <queue>
3938

4039
using namespace llvm;
4140

@@ -2780,50 +2779,22 @@ static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals,
27802779
// Scalable vectors cannot use arbitrary shufflevectors (only splats), so
27812780
// must use intrinsics to interleave.
27822781
if (VecTy->isScalableTy()) {
2783-
if (Vals.size() == 2) {
2784-
VectorType *WideVecTy = VectorType::getDoubleElementsVectorType(VecTy);
2785-
return Builder.CreateIntrinsic(WideVecTy, Intrinsic::vector_interleave2,
2786-
Vals,
2787-
/*FMFSource=*/nullptr, Name);
2788-
}
27892782
unsigned InterleaveFactor = Vals.size();
27902783
SmallVector<Value *> InterleavingValues(Vals);
2791-
// The total number of nodes in a balanced binary tree is calculated as 2n -
2792-
// 1, where `n` is the number of leaf nodes (`InterleaveFactor`). In this
2793-
// context, we exclude the root node because it will serve as the final
2794-
// interleaved value. Thus, the number of nodes to be processed/interleaved
2795-
// is: (2n - 1) - 1 = 2n - 2.
2796-
2797-
unsigned NumInterleavingValues = 2 * InterleaveFactor - 2;
2798-
for (unsigned I = 1; I < NumInterleavingValues; I += 2) {
2799-
// values that haven't been processed yet:
2800-
unsigned Remaining = InterleavingValues.size() - I + 1;
2801-
if (Remaining > 2 && isPowerOf2_32(Remaining)) {
2802-
2803-
// The remaining values form a new level in the interleaving tree.
2804-
// Arrange these values in the correct interleaving order for this
2805-
// level. The interleaving order places alternating elements from the
2806-
// first and second halves,
2807-
std::vector<Value *> RemainingValues(InterleavingValues.begin() + I - 1,
2808-
InterleavingValues.end());
2809-
unsigned Middle = Remaining / 2;
2810-
for (unsigned J = I - 1, K = 0; J < InterleavingValues.size();
2811-
J += 2, K++) {
2812-
InterleavingValues[J] = RemainingValues[K];
2813-
InterleavingValues[J + 1] = RemainingValues[Middle + K];
2814-
}
2815-
}
2784+
// As we are interleaving, the values sz will be shrinked until we have the
2785+
// single final interleaved value.
2786+
for (unsigned Midpoint = Factor / 2; Midpoint > 0; Midpoint /= 2) {
28162787
VectorType *InterleaveTy =
2817-
cast<VectorType>(InterleavingValues[I]->getType());
2788+
cast<VectorType>(InterleavingValues[0]->getType());
28182789
VectorType *WideVecTy =
28192790
VectorType::getDoubleElementsVectorType(InterleaveTy);
2820-
auto *InterleaveRes = Builder.CreateIntrinsic(
2821-
WideVecTy, Intrinsic::vector_interleave2,
2822-
{InterleavingValues[I - 1], InterleavingValues[I]},
2823-
/*FMFSource=*/nullptr, Name);
2824-
InterleavingValues.push_back(InterleaveRes);
2791+
for (unsigned I = 0; I < Midpoint; ++I)
2792+
InterleavingValues[I] = Builder.CreateIntrinsic(
2793+
WideVecTy, Intrinsic::vector_interleave2,
2794+
{InterleavingValues[I], InterleavingValues[Midpoint + I]},
2795+
/*FMFSource=*/nullptr, Name);
28252796
}
2826-
return InterleavingValues[NumInterleavingValues];
2797+
return InterleavingValues[0];
28272798
}
28282799

28292800
// Fixed length. Start by concatenating all vectors into a wide vector.
@@ -2960,49 +2931,31 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
29602931
// Scalable vectors cannot use arbitrary shufflevectors (only splats),
29612932
// so must use intrinsics to deinterleave.
29622933

2963-
SmallVector<Value *> DeinterleavedValues;
2964-
// If the InterleaveFactor is > 2, so we will have to do recursive
2934+
SmallVector<Value *> DeinterleavedValues(InterleaveFactor);
2935+
DeinterleavedValues[0] = NewLoad;
2936+
// For the case of InterleaveFactor > 2, we will have to do recursive
29652937
// deinterleaving, because the current available deinterleave intrinsic
2966-
// supports only Factor of 2. DeinterleaveCount represent how many times
2967-
// we will do deinterleaving, we will do deinterleave on all nonleaf
2968-
// nodes in the deinterleave tree.
2969-
unsigned DeinterleaveCount = InterleaveFactor - 1;
2970-
std::vector<Value *> TempDeinterleavedValues;
2971-
TempDeinterleavedValues.push_back(NewLoad);
2972-
for (unsigned I = 0; I < DeinterleaveCount; ++I) {
2973-
auto *DiTy = TempDeinterleavedValues[I]->getType();
2974-
Value *DI = State.Builder.CreateIntrinsic(
2975-
Intrinsic::vector_deinterleave2, DiTy, TempDeinterleavedValues[I],
2976-
/*FMFSource=*/nullptr, "strided.vec");
2977-
Value *StridedVec = State.Builder.CreateExtractValue(DI, 0);
2978-
TempDeinterleavedValues.push_back(StridedVec);
2979-
StridedVec = State.Builder.CreateExtractValue(DI, 1);
2980-
TempDeinterleavedValues.push_back(StridedVec);
2981-
// Perform sorting at the start of each new level in the tree.
2982-
// A new level begins when the number of remaining values is a power
2983-
// of 2 and greater than 2. If a level has only 2 nodes, no sorting is
2984-
// needed as they are already in order. Number of remaining values to
2985-
// be processed:
2986-
unsigned NumRemainingValues = TempDeinterleavedValues.size() - I - 1;
2987-
if (NumRemainingValues > 2 && isPowerOf2_32(NumRemainingValues)) {
2988-
// these remaining values represent a new level in the tree,
2989-
// Reorder the values to match the correct deinterleaving order.
2990-
std::vector<Value *> RemainingValues(
2991-
TempDeinterleavedValues.begin() + I + 1,
2992-
TempDeinterleavedValues.end());
2993-
unsigned Middle = NumRemainingValues / 2;
2994-
for (unsigned J = 0, K = I + 1; J < NumRemainingValues;
2995-
J += 2, K++) {
2996-
TempDeinterleavedValues[K] = RemainingValues[J];
2997-
TempDeinterleavedValues[Middle + K] = RemainingValues[J + 1];
2998-
}
2938+
// supports only Factor of 2, otherwise it will bailout after first
2939+
// iteration.
2940+
// As we are deinterleaving, the values will be doubled until reachingt
2941+
// to the InterleaveFactor.
2942+
for (int NumVectors = 1; NumVectors < InterleaveFactor;
2943+
NumVectors *= 2) {
2944+
// deinterleave the elements within the vector
2945+
std::vector<Value *> TempDeinterleavedValues(NumVectors);
2946+
for (int I = 0; I < NumVectors; ++I) {
2947+
auto *DiTy = DeinterleavedValues[I]->getType();
2948+
TempDeinterleavedValues[I] = State.Builder.CreateIntrinsic(
2949+
Intrinsic::vector_deinterleave2, DiTy, DeinterleavedValues[I],
2950+
/*FMFSource=*/nullptr, "strided.vec");
29992951
}
2952+
// Extract the deinterleaved values:
2953+
for (int I = 0; I < 2; ++I)
2954+
for (int J = 0; J < NumVectors; ++J)
2955+
DeinterleavedValues[NumVectors * I + J] =
2956+
State.Builder.CreateExtractValue(TempDeinterleavedValues[J],
2957+
I);
30002958
}
3001-
// Final deinterleaved values:
3002-
DeinterleavedValues.insert(DeinterleavedValues.begin(),
3003-
TempDeinterleavedValues.begin() +
3004-
InterleaveFactor - 1,
3005-
TempDeinterleavedValues.end());
30062959

30072960
#ifndef NDEBUG
30082961
for (Value *Val : DeinterleavedValues)

0 commit comments

Comments
 (0)