Skip to content

Commit 94d537b

Browse files
committed
Fix assumption about the extraction order: make it generic, then verify the order using pattern matching
Change-Id: I053e47d156c37cf4d7ab5b2af83c348b4210631a
1 parent 68e9c30 commit 94d537b

File tree

1 file changed

+66
-50
lines changed

1 file changed

+66
-50
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 66 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16906,68 +16906,85 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
1690616906
return true;
1690716907
}
1690816908

16909-
bool getDeinterleavedValues(Value *DI,
16910-
SmallVectorImpl<Value *> &DeinterleavedValues,
16911-
SmallVectorImpl<Instruction *> &DeadInsts) {
16912-
if (!DI->hasNUses(2))
16909+
bool getDeinterleavedValues(
16910+
Value *DI, SmallVectorImpl<Instruction *> &DeinterleavedValues) {
16911+
if (!DI->hasNUsesOrMore(2))
1691316912
return false;
16914-
16915-
// make sure that the users of DI are extractValue instructions
16916-
auto *Extr0 = *(++DI->user_begin());
16917-
if (!match(Extr0, m_ExtractValue<0>(m_Deinterleave2(m_Value()))))
16918-
return false;
16919-
auto *Extr1 = *(DI->user_begin());
16920-
if (!match(Extr1, m_ExtractValue<1>(m_Deinterleave2(m_Value()))))
16913+
auto *Extr1 = dyn_cast<ExtractValueInst>(*(DI->user_begin()));
16914+
auto *Extr2 = dyn_cast<ExtractValueInst>(*(++DI->user_begin()));
16915+
if (!Extr1 || !Extr2)
1692116916
return false;
1692216917

16923-
// each extractValue instruction is expected to have a single user,
16924-
// which should be another DI
16925-
if (!Extr0->hasOneUser() || !Extr1->hasOneUser())
16918+
if (!Extr1->hasNUsesOrMore(1) || !Extr2->hasNUsesOrMore(1))
1692616919
return false;
16927-
auto *DI1 = *(Extr0->user_begin());
16928-
if (!match(DI1, m_Deinterleave2(m_Value())))
16920+
auto *DI1 = *(Extr1->user_begin());
16921+
auto *DI2 = *(Extr2->user_begin());
16922+
16923+
if (!DI1->hasNUsesOrMore(2) || !DI2->hasNUsesOrMore(2))
1692916924
return false;
16930-
auto *DI2 = *(Extr1->user_begin());
16931-
if (!match(DI2, m_Deinterleave2(m_Value())))
16925+
// Leaf nodes of the deinterleave tree:
16926+
auto *A = dyn_cast<ExtractValueInst>(*(DI1->user_begin()));
16927+
auto *B = dyn_cast<ExtractValueInst>(*(++DI1->user_begin()));
16928+
auto *C = dyn_cast<ExtractValueInst>(*(DI2->user_begin()));
16929+
auto *D = dyn_cast<ExtractValueInst>(*(++DI2->user_begin()));
16930+
// Make sure that A, B, C and D are ExtractValue instructions
16931+
// before getting the extract index
16932+
if (!A || !B || !C || !D)
1693216933
return false;
1693316934

16934-
if (!DI1->hasNUses(2) || !DI2->hasNUses(2))
16935+
DeinterleavedValues.resize(4);
16936+
// Place the values into the vector in the order of extraction:
16937+
DeinterleavedValues[A->getIndices()[0] + (Extr1->getIndices()[0] * 2)] = A;
16938+
DeinterleavedValues[B->getIndices()[0] + (Extr1->getIndices()[0] * 2)] = B;
16939+
DeinterleavedValues[C->getIndices()[0] + (Extr2->getIndices()[0] * 2)] = C;
16940+
DeinterleavedValues[D->getIndices()[0] + (Extr2->getIndices()[0] * 2)] = D;
16941+
16942+
// Make sure that A,B,C,D match the deinterleave tree pattern
16943+
if (!match(DeinterleavedValues[0],
16944+
m_ExtractValue<0>(m_Deinterleave2(
16945+
m_ExtractValue<0>(m_Deinterleave2(m_Value()))))) ||
16946+
!match(DeinterleavedValues[1],
16947+
m_ExtractValue<1>(m_Deinterleave2(
16948+
m_ExtractValue<0>(m_Deinterleave2(m_Value()))))) ||
16949+
!match(DeinterleavedValues[2],
16950+
m_ExtractValue<0>(m_Deinterleave2(
16951+
m_ExtractValue<1>(m_Deinterleave2(m_Value()))))) ||
16952+
!match(DeinterleavedValues[3],
16953+
m_ExtractValue<1>(m_Deinterleave2(
16954+
m_ExtractValue<1>(m_Deinterleave2(m_Value())))))) {
16955+
LLVM_DEBUG(dbgs() << "matching deinterleave4 failed\n");
1693516956
return false;
16936-
16937-
// Leaf nodes of the deinterleave tree
16938-
auto *A = *(++DI1->user_begin());
16939-
auto *C = *(DI1->user_begin());
16940-
auto *B = *(++DI2->user_begin());
16941-
auto *D = *(DI2->user_begin());
16942-
16943-
DeinterleavedValues.push_back(A);
16944-
DeinterleavedValues.push_back(B);
16945-
DeinterleavedValues.push_back(C);
16946-
DeinterleavedValues.push_back(D);
16947-
16948-
// These Values will not be used anymore,
16949-
// DI4 will be created instead of nested DI1 and DI2
16950-
DeadInsts.push_back(cast<Instruction>(DI1));
16951-
DeadInsts.push_back(cast<Instruction>(Extr0));
16952-
DeadInsts.push_back(cast<Instruction>(DI2));
16953-
DeadInsts.push_back(cast<Instruction>(Extr1));
16954-
16957+
}
16958+
// Order the values according to the deinterleaving order.
16959+
std::swap(DeinterleavedValues[1], DeinterleavedValues[2]);
1695516960
return true;
1695616961
}
1695716962

16963+
void deleteDeadDeinterleaveInstructions(Instruction *DeadRoot) {
16964+
Value *DeadDeinterleave = nullptr, *DeadExtract = nullptr;
16965+
match(DeadRoot, m_ExtractValue(m_Value(DeadDeinterleave)));
16966+
assert(DeadDeinterleave != nullptr && "Match is expected to succeed");
16967+
match(DeadDeinterleave, m_Deinterleave2(m_Value(DeadExtract)));
16968+
assert(DeadExtract != nullptr && "Match is expected to succeed");
16969+
DeadRoot->eraseFromParent();
16970+
if (DeadDeinterleave->getNumUses() == 0)
16971+
cast<Instruction>(DeadDeinterleave)->eraseFromParent();
16972+
if (DeadExtract->getNumUses() == 0)
16973+
cast<Instruction>(DeadExtract)->eraseFromParent();
16974+
}
16975+
1695816976
bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
1695916977
IntrinsicInst *DI, LoadInst *LI) const {
1696016978
// Only deinterleave2 supported at present.
1696116979
if (DI->getIntrinsicID() != Intrinsic::vector_deinterleave2)
1696216980
return false;
1696316981

16964-
SmallVector<Value *, 4> DeinterleavedValues;
16965-
SmallVector<Instruction *, 10> DeadInsts;
16982+
SmallVector<Instruction *, 4> DeinterleavedValues;
1696616983
const DataLayout &DL = DI->getModule()->getDataLayout();
1696716984
unsigned Factor = 2;
1696816985
VectorType *VTy = cast<VectorType>(DI->getType()->getContainedType(0));
1696916986

16970-
if (getDeinterleavedValues(DI, DeinterleavedValues, DeadInsts)) {
16987+
if (getDeinterleavedValues(DI, DeinterleavedValues)) {
1697116988
Factor = DeinterleavedValues.size();
1697216989
VTy = cast<VectorType>(DeinterleavedValues[0]->getType());
1697316990
}
@@ -17014,7 +17031,7 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
1701417031
LdN = Builder.CreateCall(LdNFunc, Address, "ldN");
1701517032
Value *Idx =
1701617033
Builder.getInt64(I * LdTy->getElementCount().getKnownMinValue());
17017-
for (int J = 0; J < Factor; ++J) {
17034+
for (unsigned J = 0; J < Factor; ++J) {
1701817035
WideValues[J] = Builder.CreateInsertVector(
1701917036
VTy, WideValues[J], Builder.CreateExtractValue(LdN, J), Idx);
1702017037
}
@@ -17024,7 +17041,7 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
1702417041
else
1702517042
Result = PoisonValue::get(StructType::get(VTy, VTy, VTy, VTy));
1702617043
// Construct the wide result out of the small results.
17027-
for (int J = 0; J < Factor; ++J) {
17044+
for (unsigned J = 0; J < Factor; ++J) {
1702817045
Result = Builder.CreateInsertValue(Result, WideValues[J], J);
1702917046
}
1703017047
} else {
@@ -17034,15 +17051,14 @@ bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
1703417051
Result = Builder.CreateCall(LdNFunc, BaseAddr, "ldN");
1703517052
}
1703617053
if (Factor > 2) {
17054+
// Iterate over the old deinterleaved values to replace them with
17055+
// the new deinterleaved values.
1703717056
for (unsigned I = 0; I < DeinterleavedValues.size(); I++) {
17038-
llvm::Value *CurrentExtract = DeinterleavedValues[I];
1703917057
Value *NewExtract = Builder.CreateExtractValue(Result, I);
17040-
CurrentExtract->replaceAllUsesWith(NewExtract);
17041-
cast<Instruction>(CurrentExtract)->eraseFromParent();
17058+
DeinterleavedValues[I]->replaceAllUsesWith(NewExtract);
1704217059
}
17043-
17044-
for (auto &dead : DeadInsts)
17045-
dead->eraseFromParent();
17060+
for (unsigned I = 0; I < DeinterleavedValues.size(); I++)
17061+
deleteDeadDeinterleaveInstructions(DeinterleavedValues[I]);
1704617062
return true;
1704717063
}
1704817064
DI->replaceAllUsesWith(Result);
@@ -17124,7 +17140,7 @@ bool AArch64TargetLowering::lowerInterleaveIntrinsicToStore(
1712417140
Address = Builder.CreateGEP(StTy, BaseAddr, {Offset});
1712517141
Value *Idx =
1712617142
Builder.getInt64(I * StTy->getElementCount().getKnownMinValue());
17127-
for (int J = 0; J < Factor; J++) {
17143+
for (unsigned J = 0; J < Factor; J++) {
1712817144
ValuesToInterleave[J] =
1712917145
Builder.CreateExtractVector(StTy, WideValues[J], Idx);
1713017146
}

0 commit comments

Comments
 (0)