Skip to content

Commit 85a17a4

Browse files
committed
[AMDGPU] Accept arbitrary sized sources in CalculateByteProvider
This allows working with e.g. v8i8 / v16i8 sources. It is generally useful, but is primarily beneficial when allowing e.g. v8i8s to be passed to branches directly through registers. As such, this is the first in a series of patches to enable that work. However, it effects https://reviews.llvm.org/D155995, so it has been implemented on top of that. Differential Revision: https://reviews.llvm.org/D159036 Change-Id: Idfcb57dacd0c32cab040fe4dd4ac2ec762750664
1 parent bcd1490 commit 85a17a4

File tree

6 files changed

+1498
-113
lines changed

6 files changed

+1498
-113
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 124 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -11567,8 +11567,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
1156711567
if (Depth >= 6)
1156811568
return std::nullopt;
1156911569

11570-
auto ValueSize = Op.getValueSizeInBits();
11571-
if (ValueSize != 8 && ValueSize != 16 && ValueSize != 32)
11570+
if (Op.getValueSizeInBits() < 8)
1157211571
return std::nullopt;
1157311572

1157411573
switch (Op->getOpcode()) {
@@ -11827,8 +11826,6 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1182711826
auto VecIdx = IdxOp->getZExtValue();
1182811827
auto ScalarSize = Op.getScalarValueSizeInBits();
1182911828
if (ScalarSize != 32) {
11830-
if ((VecIdx + 1) * ScalarSize > 32)
11831-
return std::nullopt;
1183211829
Index = ScalarSize == 8 ? VecIdx : VecIdx * 2 + Index;
1183311830
}
1183411831

@@ -11913,9 +11910,6 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
1191311910
int Low16 = PermMask & 0xffff;
1191411911
int Hi16 = (PermMask & 0xffff0000) >> 16;
1191511912

11916-
assert(Op.getValueType().isByteSized());
11917-
assert(OtherOp.getValueType().isByteSized());
11918-
1191911913
auto TempOp = peekThroughBitcasts(Op);
1192011914
auto TempOtherOp = peekThroughBitcasts(OtherOp);
1192111915

@@ -11933,15 +11927,38 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
1193311927
return !addresses16Bits(Low16) || !addresses16Bits(Hi16);
1193411928
}
1193511929

11930+
static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
11931+
unsigned DWordOffset) {
11932+
SDValue Ret;
11933+
if (Src.getValueSizeInBits() <= 32)
11934+
return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
11935+
11936+
if (Src.getValueSizeInBits() >= 256) {
11937+
assert(!(Src.getValueSizeInBits() % 32));
11938+
Ret = DAG.getBitcast(
11939+
MVT::getVectorVT(MVT::i32, Src.getValueSizeInBits() / 32), Src);
11940+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Ret,
11941+
DAG.getConstant(DWordOffset, SL, MVT::i32));
11942+
}
11943+
11944+
Ret = DAG.getBitcastedAnyExtOrTrunc(
11945+
Src, SL, MVT::getIntegerVT(Src.getValueSizeInBits()));
11946+
if (DWordOffset) {
11947+
auto Shifted = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11948+
DAG.getConstant(DWordOffset * 32, SL, MVT::i32));
11949+
return DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, Shifted);
11950+
}
11951+
11952+
return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11953+
}
11954+
1193611955
static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1193711956
SelectionDAG &DAG = DCI.DAG;
1193811957
EVT VT = N->getValueType(0);
11939-
11940-
if (VT != MVT::i32)
11941-
return SDValue();
11958+
SmallVector<ByteProvider<SDValue>, 8> PermNodes;
1194211959

1194311960
// VT is known to be MVT::i32, so we need to provide 4 bytes.
11944-
SmallVector<ByteProvider<SDValue>, 8> PermNodes;
11961+
assert(VT == MVT::i32);
1194511962
for (int i = 0; i < 4; i++) {
1194611963
// Find the ByteProvider that provides the ith byte of the result of OR
1194711964
std::optional<ByteProvider<SDValue>> P =
@@ -11955,42 +11972,40 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1195511972
if (PermNodes.size() != 4)
1195611973
return SDValue();
1195711974

11958-
int FirstSrc = 0;
11959-
std::optional<int> SecondSrc;
11975+
std::pair<unsigned, unsigned> FirstSrc(0, PermNodes[0].SrcOffset / 4);
11976+
std::optional<std::pair<unsigned, unsigned>> SecondSrc;
1196011977
uint64_t PermMask = 0x00000000;
1196111978
for (size_t i = 0; i < PermNodes.size(); i++) {
1196211979
auto PermOp = PermNodes[i];
1196311980
// Since the mask is applied to Src1:Src2, Src1 bytes must be offset
1196411981
// by sizeof(Src2) = 4
1196511982
int SrcByteAdjust = 4;
1196611983

11967-
if (!PermOp.hasSameSrc(PermNodes[FirstSrc])) {
11968-
if (SecondSrc.has_value())
11969-
if (!PermOp.hasSameSrc(PermNodes[*SecondSrc]))
11984+
// If the Src uses a byte from a different DWORD, then it corresponds
11985+
// with a difference source
11986+
if (!PermOp.hasSameSrc(PermNodes[FirstSrc.first]) ||
11987+
((PermOp.SrcOffset / 4) != FirstSrc.second)) {
11988+
if (SecondSrc)
11989+
if (!PermOp.hasSameSrc(PermNodes[SecondSrc->first]) ||
11990+
((PermOp.SrcOffset / 4) != SecondSrc->second))
1197011991
return SDValue();
1197111992

1197211993
// Set the index of the second distinct Src node
11973-
SecondSrc = i;
11974-
assert(!(PermNodes[*SecondSrc].Src->getValueSizeInBits() % 8));
11994+
SecondSrc = {i, PermNodes[i].SrcOffset / 4};
11995+
assert(!(PermNodes[SecondSrc->first].Src->getValueSizeInBits() % 8));
1197511996
SrcByteAdjust = 0;
1197611997
}
11977-
assert(PermOp.SrcOffset + SrcByteAdjust < 8);
11998+
assert((PermOp.SrcOffset % 4) + SrcByteAdjust < 8);
1197811999
assert(!DAG.getDataLayout().isBigEndian());
11979-
PermMask |= (PermOp.SrcOffset + SrcByteAdjust) << (i * 8);
12000+
PermMask |= ((PermOp.SrcOffset % 4) + SrcByteAdjust) << (i * 8);
1198012001
}
11981-
11982-
SDValue Op = *PermNodes[FirstSrc].Src;
11983-
SDValue OtherOp = SecondSrc.has_value() ? *PermNodes[*SecondSrc].Src
11984-
: *PermNodes[FirstSrc].Src;
11985-
11986-
// Check that we haven't just recreated the same FSHR node.
11987-
if (N->getOpcode() == ISD::FSHR &&
11988-
(N->getOperand(0) == Op || N->getOperand(0) == OtherOp) &&
11989-
(N->getOperand(1) == Op || N->getOperand(1) == OtherOp))
11990-
return SDValue();
12002+
SDLoc DL(N);
12003+
SDValue Op = *PermNodes[FirstSrc.first].Src;
12004+
Op = getDWordFromOffset(DAG, DL, Op, FirstSrc.second);
12005+
assert(Op.getValueSizeInBits() == 32);
1199112006

1199212007
// Check that we are not just extracting the bytes in order from an op
11993-
if (Op == OtherOp && Op.getValueSizeInBits() == 32) {
12008+
if (!SecondSrc) {
1199412009
int Low16 = PermMask & 0xffff;
1199512010
int Hi16 = (PermMask & 0xffff0000) >> 16;
1199612011

@@ -12002,8 +12017,16 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1200212017
return DAG.getBitcast(MVT::getIntegerVT(32), Op);
1200312018
}
1200412019

12020+
SDValue OtherOp =
12021+
SecondSrc.has_value() ? *PermNodes[SecondSrc->first].Src : Op;
12022+
12023+
if (SecondSrc)
12024+
OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
12025+
12026+
assert(Op.getValueSizeInBits() == 32);
12027+
1200512028
if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
12006-
SDLoc DL(N);
12029+
1200712030
assert(Op.getValueType().isByteSized() &&
1200812031
OtherOp.getValueType().isByteSized());
1200912032

@@ -12018,7 +12041,6 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1201812041
return DAG.getNode(AMDGPUISD::PERM, DL, MVT::i32, Op, OtherOp,
1201912042
DAG.getConstant(PermMask, DL, MVT::i32));
1202012043
}
12021-
1202212044
return SDValue();
1202312045
}
1202412046

@@ -13530,17 +13552,24 @@ static unsigned addPermMasks(unsigned First, unsigned Second) {
1353013552
return (FirstNoCs | SecondNoCs) | (FirstCs & SecondCs);
1353113553
}
1353213554

13555+
struct DotSrc {
13556+
SDValue SrcOp;
13557+
int64_t PermMask;
13558+
int64_t DWordOffset;
13559+
};
13560+
1353313561
static void placeSources(ByteProvider<SDValue> &Src0,
1353413562
ByteProvider<SDValue> &Src1,
13535-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Src0s,
13536-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Src1s,
13537-
int Step) {
13563+
SmallVectorImpl<DotSrc> &Src0s,
13564+
SmallVectorImpl<DotSrc> &Src1s, int Step) {
1353813565

1353913566
assert(Src0.Src.has_value() && Src1.Src.has_value());
1354013567
// Src0s and Src1s are empty, just place arbitrarily.
1354113568
if (Step == 0) {
13542-
Src0s.push_back({*Src0.Src, (Src0.SrcOffset << 24) + 0x0c0c0c});
13543-
Src1s.push_back({*Src1.Src, (Src1.SrcOffset << 24) + 0x0c0c0c});
13569+
Src0s.push_back({*Src0.Src, ((Src0.SrcOffset % 4) << 24) + 0x0c0c0c,
13570+
Src0.SrcOffset / 4});
13571+
Src1s.push_back({*Src1.Src, ((Src1.SrcOffset % 4) << 24) + 0x0c0c0c,
13572+
Src1.SrcOffset / 4});
1354413573
return;
1354513574
}
1354613575

@@ -13553,38 +13582,38 @@ static void placeSources(ByteProvider<SDValue> &Src0,
1355313582
unsigned FMask = 0xFF << (8 * (3 - Step));
1355413583

1355513584
unsigned FirstMask =
13556-
BPP.first.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
13585+
(BPP.first.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
1355713586
unsigned SecondMask =
13558-
BPP.second.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
13587+
(BPP.second.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
1355913588
// Attempt to find Src vector which contains our SDValue, if so, add our
1356013589
// perm mask to the existing one. If we are unable to find a match for the
1356113590
// first SDValue, attempt to find match for the second.
1356213591
int FirstGroup = -1;
1356313592
for (int I = 0; I < 2; I++) {
13564-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
13565-
I == 0 ? Src0s : Src1s;
13566-
auto MatchesFirst = [&BPP](std::pair<SDValue, unsigned> IterElt) {
13567-
return IterElt.first == *BPP.first.Src;
13593+
SmallVectorImpl<DotSrc> &Srcs = I == 0 ? Src0s : Src1s;
13594+
auto MatchesFirst = [&BPP](DotSrc &IterElt) {
13595+
return IterElt.SrcOp == *BPP.first.Src &&
13596+
(IterElt.DWordOffset == (BPP.first.SrcOffset / 4));
1356813597
};
1356913598

1357013599
auto Match = llvm::find_if(Srcs, MatchesFirst);
1357113600
if (Match != Srcs.end()) {
13572-
Match->second = addPermMasks(FirstMask, Match->second);
13601+
Match->PermMask = addPermMasks(FirstMask, Match->PermMask);
1357313602
FirstGroup = I;
1357413603
break;
1357513604
}
1357613605
}
1357713606
if (FirstGroup != -1) {
13578-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
13579-
FirstGroup == 1 ? Src0s : Src1s;
13580-
auto MatchesSecond = [&BPP](std::pair<SDValue, unsigned> IterElt) {
13581-
return IterElt.first == *BPP.second.Src;
13607+
SmallVectorImpl<DotSrc> &Srcs = FirstGroup == 1 ? Src0s : Src1s;
13608+
auto MatchesSecond = [&BPP](DotSrc &IterElt) {
13609+
return IterElt.SrcOp == *BPP.second.Src &&
13610+
(IterElt.DWordOffset == (BPP.second.SrcOffset / 4));
1358213611
};
1358313612
auto Match = llvm::find_if(Srcs, MatchesSecond);
1358413613
if (Match != Srcs.end()) {
13585-
Match->second = addPermMasks(SecondMask, Match->second);
13614+
Match->PermMask = addPermMasks(SecondMask, Match->PermMask);
1358613615
} else
13587-
Srcs.push_back({*BPP.second.Src, SecondMask});
13616+
Srcs.push_back({*BPP.second.Src, SecondMask, BPP.second.SrcOffset / 4});
1358813617
return;
1358913618
}
1359013619
}
@@ -13596,29 +13625,32 @@ static void placeSources(ByteProvider<SDValue> &Src0,
1359613625
unsigned FMask = 0xFF << (8 * (3 - Step));
1359713626

1359813627
Src0s.push_back(
13599-
{*Src0.Src, (Src0.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
13628+
{*Src0.Src,
13629+
((Src0.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
13630+
Src1.SrcOffset / 4});
1360013631
Src1s.push_back(
13601-
{*Src1.Src, (Src1.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
13632+
{*Src1.Src,
13633+
((Src1.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
13634+
Src1.SrcOffset / 4});
1360213635

1360313636
return;
1360413637
}
1360513638

13606-
static SDValue
13607-
resolveSources(SelectionDAG &DAG, SDLoc SL,
13608-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
13609-
bool IsSigned, bool IsAny) {
13639+
static SDValue resolveSources(SelectionDAG &DAG, SDLoc SL,
13640+
SmallVectorImpl<DotSrc> &Srcs, bool IsSigned,
13641+
bool IsAny) {
1361013642

1361113643
// If we just have one source, just permute it accordingly.
1361213644
if (Srcs.size() == 1) {
1361313645
auto Elt = Srcs.begin();
13614-
auto EltVal = DAG.getBitcastedAnyExtOrTrunc(Elt->first, SL, MVT::i32);
13646+
auto EltOp = getDWordFromOffset(DAG, SL, Elt->SrcOp, Elt->DWordOffset);
1361513647

13616-
// v_perm will produce the original value.
13617-
if (Elt->second == 0x3020100)
13618-
return EltVal;
13648+
// v_perm will produce the original value
13649+
if (Elt->PermMask == 0x3020100)
13650+
return EltOp;
1361913651

13620-
return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal,
13621-
DAG.getConstant(Elt->second, SL, MVT::i32));
13652+
return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp,
13653+
DAG.getConstant(Elt->PermMask, SL, MVT::i32));
1362213654
}
1362313655

1362413656
auto FirstElt = Srcs.begin();
@@ -13629,8 +13661,8 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
1362913661
// If we have multiple sources in the chain, combine them via perms (using
1363013662
// calculated perm mask) and Ors.
1363113663
while (true) {
13632-
auto FirstMask = FirstElt->second;
13633-
auto SecondMask = SecondElt->second;
13664+
auto FirstMask = FirstElt->PermMask;
13665+
auto SecondMask = SecondElt->PermMask;
1363413666

1363513667
unsigned FirstCs = FirstMask & 0x0c0c0c0c;
1363613668
unsigned FirstPlusFour = FirstMask | 0x04040404;
@@ -13640,9 +13672,9 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
1364013672

1364113673
auto PermMask = addPermMasks(FirstMask, SecondMask);
1364213674
auto FirstVal =
13643-
DAG.getBitcastedAnyExtOrTrunc(FirstElt->first, SL, MVT::i32);
13675+
getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
1364413676
auto SecondVal =
13645-
DAG.getBitcastedAnyExtOrTrunc(SecondElt->first, SL, MVT::i32);
13677+
getDWordFromOffset(DAG, SL, SecondElt->SrcOp, SecondElt->DWordOffset);
1364613678

1364713679
Perms.push_back(DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, FirstVal,
1364813680
SecondVal,
@@ -13656,12 +13688,12 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
1365613688
// If we only have a FirstElt, then just combine that into the cumulative
1365713689
// source node.
1365813690
if (SecondElt == Srcs.end()) {
13659-
auto EltVal =
13660-
DAG.getBitcastedAnyExtOrTrunc(FirstElt->first, SL, MVT::i32);
13691+
auto EltOp =
13692+
getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
1366113693

1366213694
Perms.push_back(
13663-
DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal,
13664-
DAG.getConstant(FirstElt->second, SL, MVT::i32)));
13695+
DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp,
13696+
DAG.getConstant(FirstElt->PermMask, SL, MVT::i32)));
1366513697
break;
1366613698
}
1366713699
}
@@ -13672,9 +13704,8 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
1367213704
: Perms[0];
1367313705
}
1367413706

13675-
static void fixMasks(SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
13676-
unsigned ChainLength) {
13677-
for (auto &[EntryVal, EntryMask] : Srcs) {
13707+
static void fixMasks(SmallVectorImpl<DotSrc> &Srcs, unsigned ChainLength) {
13708+
for (auto &[EntryVal, EntryMask, EntryOffset] : Srcs) {
1367813709
EntryMask = EntryMask >> ((4 - ChainLength) * 8);
1367913710
auto ZeroMask = ChainLength == 2 ? 0x0c0c0000 : 0x0c000000;
1368013711
EntryMask += ZeroMask;
@@ -13774,8 +13805,8 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1377413805
(Subtarget->hasDot1Insts() || Subtarget->hasDot8Insts())) {
1377513806
SDValue TempNode(N, 0);
1377613807
std::optional<bool> IsSigned;
13777-
SmallVector<std::pair<SDValue, unsigned>, 4> Src0s;
13778-
SmallVector<std::pair<SDValue, unsigned>, 4> Src1s;
13808+
SmallVector<DotSrc, 4> Src0s;
13809+
SmallVector<DotSrc, 4> Src1s;
1377913810
SmallVector<SDValue, 4> Src2s;
1378013811

1378113812
// Match the v_dot4 tree, while collecting src nodes.
@@ -13857,11 +13888,11 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1385713888
// (commutation).
1385813889
bool UseOriginalSrc = false;
1385913890
if (ChainLength == 4 && Src0s.size() == 1 && Src1s.size() == 1 &&
13860-
Src0s.begin()->second == Src1s.begin()->second &&
13861-
Src0s.begin()->first.getValueSizeInBits() == 32 &&
13862-
Src1s.begin()->first.getValueSizeInBits() == 32) {
13891+
Src0s.begin()->PermMask == Src1s.begin()->PermMask &&
13892+
Src0s.begin()->SrcOp.getValueSizeInBits() >= 32 &&
13893+
Src1s.begin()->SrcOp.getValueSizeInBits() >= 32) {
1386313894
SmallVector<unsigned, 4> SrcBytes;
13864-
auto Src0Mask = Src0s.begin()->second;
13895+
auto Src0Mask = Src0s.begin()->PermMask;
1386513896
SrcBytes.push_back(Src0Mask & 0xFF000000);
1386613897
bool UniqueEntries = true;
1386713898
for (auto I = 1; I < 4; I++) {
@@ -13876,11 +13907,19 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1387613907

1387713908
if (UniqueEntries) {
1387813909
UseOriginalSrc = true;
13879-
// Must be 32 bits to enter above conditional.
13880-
assert(Src0s.begin()->first.getValueSizeInBits() == 32);
13881-
assert(Src1s.begin()->first.getValueSizeInBits() == 32);
13882-
Src0 = DAG.getBitcast(MVT::getIntegerVT(32), Src0s.begin()->first);
13883-
Src1 = DAG.getBitcast(MVT::getIntegerVT(32), Src1s.begin()->first);
13910+
13911+
auto FirstElt = Src0s.begin();
13912+
auto FirstEltOp =
13913+
getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
13914+
13915+
auto SecondElt = Src1s.begin();
13916+
auto SecondEltOp = getDWordFromOffset(DAG, SL, SecondElt->SrcOp,
13917+
SecondElt->DWordOffset);
13918+
13919+
Src0 = DAG.getBitcastedAnyExtOrTrunc(FirstEltOp, SL,
13920+
MVT::getIntegerVT(32));
13921+
Src1 = DAG.getBitcastedAnyExtOrTrunc(SecondEltOp, SL,
13922+
MVT::getIntegerVT(32));
1388413923
}
1388513924
}
1388613925

0 commit comments

Comments
 (0)