@@ -16906,68 +16906,85 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
16906
16906
return true;
16907
16907
}
16908
16908
16909
- bool getDeinterleavedValues(Value *DI,
16910
- SmallVectorImpl<Value *> &DeinterleavedValues,
16911
- SmallVectorImpl<Instruction *> &DeadInsts) {
16912
- if (!DI->hasNUses(2))
16909
+ bool getDeinterleavedValues(
16910
+ Value *DI, SmallVectorImpl<Instruction *> &DeinterleavedValues) {
16911
+ if (!DI->hasNUsesOrMore(2))
16913
16912
return false;
16914
-
16915
- // make sure that the users of DI are extractValue instructions
16916
- auto *Extr0 = *(++DI->user_begin());
16917
- if (!match(Extr0, m_ExtractValue<0>(m_Deinterleave2(m_Value()))))
16918
- return false;
16919
- auto *Extr1 = *(DI->user_begin());
16920
- if (!match(Extr1, m_ExtractValue<1>(m_Deinterleave2(m_Value()))))
16913
+ auto *Extr1 = dyn_cast<ExtractValueInst>(*(DI->user_begin()));
16914
+ auto *Extr2 = dyn_cast<ExtractValueInst>(*(++DI->user_begin()));
16915
+ if (!Extr1 || !Extr2)
16921
16916
return false;
16922
16917
16923
- // each extractValue instruction is expected to have a single user,
16924
- // which should be another DI
16925
- if (!Extr0->hasOneUser() || !Extr1->hasOneUser())
16918
+ if (!Extr1->hasNUsesOrMore(1) || !Extr2->hasNUsesOrMore(1))
16926
16919
return false;
16927
- auto *DI1 = *(Extr0->user_begin());
16928
- if (!match(DI1, m_Deinterleave2(m_Value())))
16920
+ auto *DI1 = *(Extr1->user_begin());
16921
+ auto *DI2 = *(Extr2->user_begin());
16922
+
16923
+ if (!DI1->hasNUsesOrMore(2) || !DI2->hasNUsesOrMore(2))
16929
16924
return false;
16930
- auto *DI2 = *(Extr1->user_begin());
16931
- if (!match(DI2, m_Deinterleave2(m_Value())))
16925
+ // Leaf nodes of the deinterleave tree:
16926
+ auto *A = dyn_cast<ExtractValueInst>(*(DI1->user_begin()));
16927
+ auto *B = dyn_cast<ExtractValueInst>(*(++DI1->user_begin()));
16928
+ auto *C = dyn_cast<ExtractValueInst>(*(DI2->user_begin()));
16929
+ auto *D = dyn_cast<ExtractValueInst>(*(++DI2->user_begin()));
16930
+ // Make sure that the A,B,C,D are instructions of ExtractValue,
16931
+ // before getting the extract index
16932
+ if (!A || !B || !C || !D)
16932
16933
return false;
16933
16934
16934
- if (!DI1->hasNUses(2) || !DI2->hasNUses(2))
16935
+ DeinterleavedValues.resize(4);
16936
+ // Place the values into the vector in the order of extraction:
16937
+ DeinterleavedValues[A->getIndices()[0] + (Extr1->getIndices()[0] * 2)] = A;
16938
+ DeinterleavedValues[B->getIndices()[0] + (Extr1->getIndices()[0] * 2)] = B;
16939
+ DeinterleavedValues[C->getIndices()[0] + (Extr2->getIndices()[0] * 2)] = C;
16940
+ DeinterleavedValues[D->getIndices()[0] + (Extr2->getIndices()[0] * 2)] = D;
16941
+
16942
+ // Make sure that A,B,C,D match the deinterleave tree pattern
16943
+ if (!match(DeinterleavedValues[0],
16944
+ m_ExtractValue<0>(m_Deinterleave2(
16945
+ m_ExtractValue<0>(m_Deinterleave2(m_Value()))))) ||
16946
+ !match(DeinterleavedValues[1],
16947
+ m_ExtractValue<1>(m_Deinterleave2(
16948
+ m_ExtractValue<0>(m_Deinterleave2(m_Value()))))) ||
16949
+ !match(DeinterleavedValues[2],
16950
+ m_ExtractValue<0>(m_Deinterleave2(
16951
+ m_ExtractValue<1>(m_Deinterleave2(m_Value()))))) ||
16952
+ !match(DeinterleavedValues[3],
16953
+ m_ExtractValue<1>(m_Deinterleave2(
16954
+ m_ExtractValue<1>(m_Deinterleave2(m_Value())))))) {
16955
+ LLVM_DEBUG(dbgs() << "matching deinterleave4 failed\n");
16935
16956
return false;
16936
-
16937
- // Leaf nodes of the deinterleave tree
16938
- auto *A = *(++DI1->user_begin());
16939
- auto *C = *(DI1->user_begin());
16940
- auto *B = *(++DI2->user_begin());
16941
- auto *D = *(DI2->user_begin());
16942
-
16943
- DeinterleavedValues.push_back(A);
16944
- DeinterleavedValues.push_back(B);
16945
- DeinterleavedValues.push_back(C);
16946
- DeinterleavedValues.push_back(D);
16947
-
16948
- // These Values will not be used anymre,
16949
- // DI4 will be created instead of nested DI1 and DI2
16950
- DeadInsts.push_back(cast<Instruction>(DI1));
16951
- DeadInsts.push_back(cast<Instruction>(Extr0));
16952
- DeadInsts.push_back(cast<Instruction>(DI2));
16953
- DeadInsts.push_back(cast<Instruction>(Extr1));
16954
-
16957
+ }
16958
+ // Order the values according to the deinterleaving order.
16959
+ std::swap(DeinterleavedValues[1], DeinterleavedValues[2]);
16955
16960
return true;
16956
16961
}
16957
16962
16963
+ void deleteDeadDeinterleaveInstructions(Instruction *DeadRoot) {
16964
+ Value *DeadDeinterleave = nullptr, *DeadExtract = nullptr;
16965
+ match(DeadRoot, m_ExtractValue(m_Value(DeadDeinterleave)));
16966
+ assert(DeadDeinterleave != nullptr && "Match is expected to succeed");
16967
+ match(DeadDeinterleave, m_Deinterleave2(m_Value(DeadExtract)));
16968
+ assert(DeadExtract != nullptr && "Match is expected to succeed");
16969
+ DeadRoot->eraseFromParent();
16970
+ if (DeadDeinterleave->getNumUses() == 0)
16971
+ cast<Instruction>(DeadDeinterleave)->eraseFromParent();
16972
+ if (DeadExtract->getNumUses() == 0)
16973
+ cast<Instruction>(DeadExtract)->eraseFromParent();
16974
+ }
16975
+
16958
16976
bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
16959
16977
IntrinsicInst *DI, LoadInst *LI) const {
16960
16978
// Only deinterleave2 supported at present.
16961
16979
if (DI->getIntrinsicID() != Intrinsic::vector_deinterleave2)
16962
16980
return false;
16963
16981
16964
- SmallVector<Value *, 4> DeinterleavedValues;
16965
- SmallVector<Instruction *, 10> DeadInsts;
16982
+ SmallVector<Instruction *, 4> DeinterleavedValues;
16966
16983
const DataLayout &DL = DI->getModule()->getDataLayout();
16967
16984
unsigned Factor = 2;
16968
16985
VectorType *VTy = cast<VectorType>(DI->getType()->getContainedType(0));
16969
16986
16970
- if (getDeinterleavedValues(DI, DeinterleavedValues, DeadInsts )) {
16987
+ if (getDeinterleavedValues(DI, DeinterleavedValues)) {
16971
16988
Factor = DeinterleavedValues.size();
16972
16989
VTy = cast<VectorType>(DeinterleavedValues[0]->getType());
16973
16990
}
@@ -17014,7 +17031,7 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
17014
17031
LdN = Builder.CreateCall(LdNFunc, Address, "ldN");
17015
17032
Value *Idx =
17016
17033
Builder.getInt64(I * LdTy->getElementCount().getKnownMinValue());
17017
- for (int J = 0; J < Factor; ++J) {
17034
+ for (unsigned J = 0; J < Factor; ++J) {
17018
17035
WideValues[J] = Builder.CreateInsertVector(
17019
17036
VTy, WideValues[J], Builder.CreateExtractValue(LdN, J), Idx);
17020
17037
}
@@ -17024,7 +17041,7 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
17024
17041
else
17025
17042
Result = PoisonValue::get(StructType::get(VTy, VTy, VTy, VTy));
17026
17043
// Construct the wide result out of the small results.
17027
- for (int J = 0; J < Factor; ++J) {
17044
+ for (unsigned J = 0; J < Factor; ++J) {
17028
17045
Result = Builder.CreateInsertValue(Result, WideValues[J], J);
17029
17046
}
17030
17047
} else {
@@ -17034,15 +17051,14 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
17034
17051
Result = Builder.CreateCall(LdNFunc, BaseAddr, "ldN");
17035
17052
}
17036
17053
if (Factor > 2) {
17054
+ // Itereate over old deinterleaved values to replace it by
17055
+ // the new deinterleaved values.
17037
17056
for (unsigned I = 0; I < DeinterleavedValues.size(); I++) {
17038
- llvm::Value *CurrentExtract = DeinterleavedValues[I];
17039
17057
Value *NewExtract = Builder.CreateExtractValue(Result, I);
17040
- CurrentExtract->replaceAllUsesWith(NewExtract);
17041
- cast<Instruction>(CurrentExtract)->eraseFromParent();
17058
+ DeinterleavedValues[I]->replaceAllUsesWith(NewExtract);
17042
17059
}
17043
-
17044
- for (auto &dead : DeadInsts)
17045
- dead->eraseFromParent();
17060
+ for (unsigned I = 0; I < DeinterleavedValues.size(); I++)
17061
+ deleteDeadDeinterleaveInstructions(DeinterleavedValues[I]);
17046
17062
return true;
17047
17063
}
17048
17064
DI->replaceAllUsesWith(Result);
@@ -17124,7 +17140,7 @@ bool AArch64TargetLowering::lowerInterleaveIntrinsicToStore(
17124
17140
Address = Builder.CreateGEP(StTy, BaseAddr, {Offset});
17125
17141
Value *Idx =
17126
17142
Builder.getInt64(I * StTy->getElementCount().getKnownMinValue());
17127
- for (int J = 0; J < Factor; J++) {
17143
+ for (unsigned J = 0; J < Factor; J++) {
17128
17144
ValuesToInterleave[J] =
17129
17145
Builder.CreateExtractVector(StTy, WideValues[J], Idx);
17130
17146
}
0 commit comments