Skip to content

Commit d0f3ae0

Browse files
committed
[resolve review comments] Reorder the values at each level of the (de)interleaving tree so that the nodes are in the correct order for (de)interleaving.
Change-Id: I151a53c459a7f69e35feb428c1dface2fe57e9ce
1 parent 5afaab4 commit d0f3ae0

File tree

1 file changed

+63
-46
lines changed

1 file changed

+63
-46
lines changed

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 63 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2780,28 +2780,39 @@ static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals,
27802780
// Scalable vectors cannot use arbitrary shufflevectors (only splats), so
27812781
// must use intrinsics to interleave.
27822782
if (VecTy->isScalableTy()) {
2783+
if (Vals.size() == 2) {
2784+
VectorType *WideVecTy = VectorType::getDoubleElementsVectorType(VecTy);
2785+
return Builder.CreateIntrinsic(WideVecTy, Intrinsic::vector_interleave2,
2786+
Vals,
2787+
/*FMFSource=*/nullptr, Name);
2788+
}
27832789
unsigned InterleaveFactor = Vals.size();
2784-
SmallVector<Value *> InterleavingValues;
2785-
unsigned InterleavingValuesCount =
2786-
InterleaveFactor + (InterleaveFactor - 2);
2787-
InterleavingValues.resize(InterleaveFactor);
2788-
// Place the values to be interleaved in the correct order for the
2789-
// interleaving
2790-
for (unsigned I = 0, J = InterleaveFactor / 2, K = 0; K < InterleaveFactor;
2791-
K++) {
2792-
if (K % 2 == 0) {
2793-
InterleavingValues[K] = Vals[I];
2794-
I++;
2795-
} else {
2796-
InterleavingValues[K] = Vals[J];
2797-
J++;
2790+
SmallVector<Value *> InterleavingValues(Vals);
2791+
// The total number of nodes in a balanced binary tree is calculated as 2n -
2792+
// 1, where `n` is the number of leaf nodes (`InterleaveFactor`). In this
2793+
// context, we exclude the root node because it will serve as the final
2794+
// interleaved value. Thus, the number of nodes to be processed/interleaved
2795+
// is: (2n - 1) - 1 = 2n - 2.
2796+
2797+
unsigned NumInterleavingValues = 2 * InterleaveFactor - 2;
2798+
for (unsigned I = 1; I < NumInterleavingValues; I += 2) {
2799+
// values that haven't been processed yet:
2800+
unsigned Remaining = InterleavingValues.size() - I + 1;
2801+
if (Remaining > 2 && isPowerOf2_32(Remaining)) {
2802+
2803+
// The remaining values form a new level in the interleaving tree.
2804+
// Arrange these values in the correct interleaving order for this
2805+
// level. The interleaving order places alternating elements from the
2806+
// first and second halves.
2807+
std::vector<Value *> RemainingValues(InterleavingValues.begin() + I - 1,
2808+
InterleavingValues.end());
2809+
unsigned Middle = Remaining / 2;
2810+
for (unsigned J = I - 1, K = 0; J < InterleavingValues.size();
2811+
J += 2, K++) {
2812+
InterleavingValues[J] = RemainingValues[K];
2813+
InterleavingValues[J + 1] = RemainingValues[Middle + K];
2814+
}
27982815
}
2799-
}
2800-
#ifndef NDEBUG
2801-
for (Value *Val : InterleavingValues)
2802-
assert(Val && "NULL Interleaving Value");
2803-
#endif
2804-
for (unsigned I = 1; I < InterleavingValuesCount; I += 2) {
28052816
VectorType *InterleaveTy =
28062817
cast<VectorType>(InterleavingValues[I]->getType());
28072818
VectorType *WideVecTy =
@@ -2812,7 +2823,7 @@ static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals,
28122823
/*FMFSource=*/nullptr, Name);
28132824
InterleavingValues.push_back(InterleaveRes);
28142825
}
2815-
return InterleavingValues[InterleavingValuesCount];
2826+
return InterleavingValues[NumInterleavingValues];
28162827
}
28172828

28182829
// Fixed length. Start by concatenating all vectors into a wide vector.
@@ -2951,42 +2962,48 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
29512962

29522963
SmallVector<Value *> DeinterleavedValues;
29532964
// If the InterleaveFactor is > 2, we will have to do recursive
2954-
// deinterleaving, because the current available deinterleave intrinsice
2965+
// deinterleaving, because the current available deinterleave intrinsic
29552966
// supports only Factor of 2. DeinterleaveCount represent how many times
29562967
// we will do deinterleaving, we will do deinterleave on all nonleaf
29572968
// nodes in the deinterleave tree.
29582969
unsigned DeinterleaveCount = InterleaveFactor - 1;
2959-
std::queue<Value *> TempDeinterleavedValues;
2960-
TempDeinterleavedValues.push(NewLoad);
2970+
std::vector<Value *> TempDeinterleavedValues;
2971+
TempDeinterleavedValues.push_back(NewLoad);
29612972
for (unsigned I = 0; I < DeinterleaveCount; ++I) {
2962-
Value *ValueToDeinterleave = TempDeinterleavedValues.front();
2963-
auto *DiTy = ValueToDeinterleave->getType();
2964-
TempDeinterleavedValues.pop();
2973+
auto *DiTy = TempDeinterleavedValues[I]->getType();
29652974
Value *DI = State.Builder.CreateIntrinsic(
2966-
Intrinsic::vector_deinterleave2, DiTy, ValueToDeinterleave,
2975+
Intrinsic::vector_deinterleave2, DiTy, TempDeinterleavedValues[I],
29672976
/*FMFSource=*/nullptr, "strided.vec");
29682977
Value *StridedVec = State.Builder.CreateExtractValue(DI, 0);
2969-
TempDeinterleavedValues.push(StridedVec);
2978+
TempDeinterleavedValues.push_back(StridedVec);
29702979
StridedVec = State.Builder.CreateExtractValue(DI, 1);
2971-
TempDeinterleavedValues.push(StridedVec);
2972-
}
2973-
2974-
assert(TempDeinterleavedValues.size() == InterleaveFactor &&
2975-
"Num of deinterleaved values must equals to InterleaveFactor");
2976-
// Sort deinterleaved values
2977-
DeinterleavedValues.resize(InterleaveFactor);
2978-
for (unsigned I = 0, J = InterleaveFactor / 2, K = 0;
2979-
K < InterleaveFactor; K++) {
2980-
auto *DeinterleavedValue = TempDeinterleavedValues.front();
2981-
TempDeinterleavedValues.pop();
2982-
if (K % 2 == 0) {
2983-
DeinterleavedValues[I] = DeinterleavedValue;
2984-
I++;
2985-
} else {
2986-
DeinterleavedValues[J] = DeinterleavedValue;
2987-
J++;
2980+
TempDeinterleavedValues.push_back(StridedVec);
2981+
// Perform sorting at the start of each new level in the tree.
2982+
// A new level begins when the number of remaining values is a power
2983+
// of 2 and greater than 2. If a level has only 2 nodes, no sorting is
2984+
// needed as they are already in order. Number of remaining values to
2985+
// be processed:
2986+
unsigned NumRemainingValues = TempDeinterleavedValues.size() - I - 1;
2987+
if (NumRemainingValues > 2 && isPowerOf2_32(NumRemainingValues)) {
2988+
// These remaining values represent a new level in the tree.
2989+
// Reorder the values to match the correct deinterleaving order.
2990+
std::vector<Value *> RemainingValues(
2991+
TempDeinterleavedValues.begin() + I + 1,
2992+
TempDeinterleavedValues.end());
2993+
unsigned Middle = NumRemainingValues / 2;
2994+
for (unsigned J = 0, K = I + 1; J < NumRemainingValues;
2995+
J += 2, K++) {
2996+
TempDeinterleavedValues[K] = RemainingValues[J];
2997+
TempDeinterleavedValues[Middle + K] = RemainingValues[J + 1];
2998+
}
29882999
}
29893000
}
3001+
// Final deinterleaved values:
3002+
DeinterleavedValues.insert(DeinterleavedValues.begin(),
3003+
TempDeinterleavedValues.begin() +
3004+
InterleaveFactor - 1,
3005+
TempDeinterleavedValues.end());
3006+
29903007
#ifndef NDEBUG
29913008
for (Value *Val : DeinterleavedValues)
29923009
assert(Val && "NULL Deinterleaved Value");

0 commit comments

Comments
 (0)