Skip to content

Commit cdd6cc0

Browse files
committed
[AMDGPU] Accept arbitrary sized sources in CalculateByteProvider
This allows working with e.g. v8i8 / v16i8 sources. It is generally useful, but is primarily beneficial when allowing e.g. v8i8s to be passed to branches directly through registers. As such, this is the first in a series of patches to enable that work. However, it effects https://reviews.llvm.org/D155995, so it has been implemented on top of that. Differential Revision: https://reviews.llvm.org/D159036 Change-Id: Idfcb57dacd0c32cab040fe4dd4ac2ec762750664
1 parent 237adfc commit cdd6cc0

File tree

6 files changed

+1498
-113
lines changed

6 files changed

+1498
-113
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 124 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -10943,8 +10943,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
1094310943
if (Depth >= 6)
1094410944
return std::nullopt;
1094510945

10946-
auto ValueSize = Op.getValueSizeInBits();
10947-
if (ValueSize != 8 && ValueSize != 16 && ValueSize != 32)
10946+
if (Op.getValueSizeInBits() < 8)
1094810947
return std::nullopt;
1094910948

1095010949
switch (Op->getOpcode()) {
@@ -11235,8 +11234,6 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1123511234
auto VecIdx = IdxOp->getZExtValue();
1123611235
auto ScalarSize = Op.getScalarValueSizeInBits();
1123711236
if (ScalarSize != 32) {
11238-
if ((VecIdx + 1) * ScalarSize > 32)
11239-
return std::nullopt;
1124011237
Index = ScalarSize == 8 ? VecIdx : VecIdx * 2 + Index;
1124111238
}
1124211239

@@ -11322,9 +11319,6 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
1132211319
int Low16 = PermMask & 0xffff;
1132311320
int Hi16 = (PermMask & 0xffff0000) >> 16;
1132411321

11325-
assert(Op.getValueType().isByteSized());
11326-
assert(OtherOp.getValueType().isByteSized());
11327-
1132811322
auto TempOp = peekThroughBitcasts(Op);
1132911323
auto TempOtherOp = peekThroughBitcasts(OtherOp);
1133011324

@@ -11342,15 +11336,38 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
1134211336
return !addresses16Bits(Low16) || !addresses16Bits(Hi16);
1134311337
}
1134411338

11339+
static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
11340+
unsigned DWordOffset) {
11341+
SDValue Ret;
11342+
if (Src.getValueSizeInBits() <= 32)
11343+
return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
11344+
11345+
if (Src.getValueSizeInBits() >= 256) {
11346+
assert(!(Src.getValueSizeInBits() % 32));
11347+
Ret = DAG.getBitcast(
11348+
MVT::getVectorVT(MVT::i32, Src.getValueSizeInBits() / 32), Src);
11349+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Ret,
11350+
DAG.getConstant(DWordOffset, SL, MVT::i32));
11351+
}
11352+
11353+
Ret = DAG.getBitcastedAnyExtOrTrunc(
11354+
Src, SL, MVT::getIntegerVT(Src.getValueSizeInBits()));
11355+
if (DWordOffset) {
11356+
auto Shifted = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11357+
DAG.getConstant(DWordOffset * 32, SL, MVT::i32));
11358+
return DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, Shifted);
11359+
}
11360+
11361+
return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11362+
}
11363+
1134511364
static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1134611365
SelectionDAG &DAG = DCI.DAG;
1134711366
EVT VT = N->getValueType(0);
11348-
11349-
if (VT != MVT::i32)
11350-
return SDValue();
11367+
SmallVector<ByteProvider<SDValue>, 8> PermNodes;
1135111368

1135211369
// VT is known to be MVT::i32, so we need to provide 4 bytes.
11353-
SmallVector<ByteProvider<SDValue>, 8> PermNodes;
11370+
assert(VT == MVT::i32);
1135411371
for (int i = 0; i < 4; i++) {
1135511372
// Find the ByteProvider that provides the ith byte of the result of OR
1135611373
std::optional<ByteProvider<SDValue>> P =
@@ -11364,42 +11381,40 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1136411381
if (PermNodes.size() != 4)
1136511382
return SDValue();
1136611383

11367-
int FirstSrc = 0;
11368-
std::optional<int> SecondSrc;
11384+
std::pair<unsigned, unsigned> FirstSrc(0, PermNodes[0].SrcOffset / 4);
11385+
std::optional<std::pair<unsigned, unsigned>> SecondSrc;
1136911386
uint64_t PermMask = 0x00000000;
1137011387
for (size_t i = 0; i < PermNodes.size(); i++) {
1137111388
auto PermOp = PermNodes[i];
1137211389
// Since the mask is applied to Src1:Src2, Src1 bytes must be offset
1137311390
// by sizeof(Src2) = 4
1137411391
int SrcByteAdjust = 4;
1137511392

11376-
if (!PermOp.hasSameSrc(PermNodes[FirstSrc])) {
11377-
if (SecondSrc.has_value())
11378-
if (!PermOp.hasSameSrc(PermNodes[*SecondSrc]))
11393+
// If the Src uses a byte from a different DWORD, then it corresponds
11394+
// with a difference source
11395+
if (!PermOp.hasSameSrc(PermNodes[FirstSrc.first]) ||
11396+
((PermOp.SrcOffset / 4) != FirstSrc.second)) {
11397+
if (SecondSrc)
11398+
if (!PermOp.hasSameSrc(PermNodes[SecondSrc->first]) ||
11399+
((PermOp.SrcOffset / 4) != SecondSrc->second))
1137911400
return SDValue();
1138011401

1138111402
// Set the index of the second distinct Src node
11382-
SecondSrc = i;
11383-
assert(!(PermNodes[*SecondSrc].Src->getValueSizeInBits() % 8));
11403+
SecondSrc = {i, PermNodes[i].SrcOffset / 4};
11404+
assert(!(PermNodes[SecondSrc->first].Src->getValueSizeInBits() % 8));
1138411405
SrcByteAdjust = 0;
1138511406
}
11386-
assert(PermOp.SrcOffset + SrcByteAdjust < 8);
11407+
assert((PermOp.SrcOffset % 4) + SrcByteAdjust < 8);
1138711408
assert(!DAG.getDataLayout().isBigEndian());
11388-
PermMask |= (PermOp.SrcOffset + SrcByteAdjust) << (i * 8);
11409+
PermMask |= ((PermOp.SrcOffset % 4) + SrcByteAdjust) << (i * 8);
1138911410
}
11390-
11391-
SDValue Op = *PermNodes[FirstSrc].Src;
11392-
SDValue OtherOp = SecondSrc.has_value() ? *PermNodes[*SecondSrc].Src
11393-
: *PermNodes[FirstSrc].Src;
11394-
11395-
// Check that we haven't just recreated the same FSHR node.
11396-
if (N->getOpcode() == ISD::FSHR &&
11397-
(N->getOperand(0) == Op || N->getOperand(0) == OtherOp) &&
11398-
(N->getOperand(1) == Op || N->getOperand(1) == OtherOp))
11399-
return SDValue();
11411+
SDLoc DL(N);
11412+
SDValue Op = *PermNodes[FirstSrc.first].Src;
11413+
Op = getDWordFromOffset(DAG, DL, Op, FirstSrc.second);
11414+
assert(Op.getValueSizeInBits() == 32);
1140011415

1140111416
// Check that we are not just extracting the bytes in order from an op
11402-
if (Op == OtherOp && Op.getValueSizeInBits() == 32) {
11417+
if (!SecondSrc) {
1140311418
int Low16 = PermMask & 0xffff;
1140411419
int Hi16 = (PermMask & 0xffff0000) >> 16;
1140511420

@@ -11411,8 +11426,16 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1141111426
return DAG.getBitcast(MVT::getIntegerVT(32), Op);
1141211427
}
1141311428

11429+
SDValue OtherOp =
11430+
SecondSrc.has_value() ? *PermNodes[SecondSrc->first].Src : Op;
11431+
11432+
if (SecondSrc)
11433+
OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
11434+
11435+
assert(Op.getValueSizeInBits() == 32);
11436+
1141411437
if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
11415-
SDLoc DL(N);
11438+
1141611439
assert(Op.getValueType().isByteSized() &&
1141711440
OtherOp.getValueType().isByteSized());
1141811441

@@ -11427,7 +11450,6 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1142711450
return DAG.getNode(AMDGPUISD::PERM, DL, MVT::i32, Op, OtherOp,
1142811451
DAG.getConstant(PermMask, DL, MVT::i32));
1142911452
}
11430-
1143111453
return SDValue();
1143211454
}
1143311455

@@ -12903,17 +12925,24 @@ static unsigned addPermMasks(unsigned First, unsigned Second) {
1290312925
return (FirstNoCs | SecondNoCs) | (FirstCs & SecondCs);
1290412926
}
1290512927

12928+
struct DotSrc {
12929+
SDValue SrcOp;
12930+
int64_t PermMask;
12931+
int64_t DWordOffset;
12932+
};
12933+
1290612934
static void placeSources(ByteProvider<SDValue> &Src0,
1290712935
ByteProvider<SDValue> &Src1,
12908-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Src0s,
12909-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Src1s,
12910-
int Step) {
12936+
SmallVectorImpl<DotSrc> &Src0s,
12937+
SmallVectorImpl<DotSrc> &Src1s, int Step) {
1291112938

1291212939
assert(Src0.Src.has_value() && Src1.Src.has_value());
1291312940
// Src0s and Src1s are empty, just place arbitrarily.
1291412941
if (Step == 0) {
12915-
Src0s.push_back({*Src0.Src, (Src0.SrcOffset << 24) + 0x0c0c0c});
12916-
Src1s.push_back({*Src1.Src, (Src1.SrcOffset << 24) + 0x0c0c0c});
12942+
Src0s.push_back({*Src0.Src, ((Src0.SrcOffset % 4) << 24) + 0x0c0c0c,
12943+
Src0.SrcOffset / 4});
12944+
Src1s.push_back({*Src1.Src, ((Src1.SrcOffset % 4) << 24) + 0x0c0c0c,
12945+
Src1.SrcOffset / 4});
1291712946
return;
1291812947
}
1291912948

@@ -12926,38 +12955,38 @@ static void placeSources(ByteProvider<SDValue> &Src0,
1292612955
unsigned FMask = 0xFF << (8 * (3 - Step));
1292712956

1292812957
unsigned FirstMask =
12929-
BPP.first.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12958+
(BPP.first.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
1293012959
unsigned SecondMask =
12931-
BPP.second.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12960+
(BPP.second.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
1293212961
// Attempt to find Src vector which contains our SDValue, if so, add our
1293312962
// perm mask to the existing one. If we are unable to find a match for the
1293412963
// first SDValue, attempt to find match for the second.
1293512964
int FirstGroup = -1;
1293612965
for (int I = 0; I < 2; I++) {
12937-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
12938-
I == 0 ? Src0s : Src1s;
12939-
auto MatchesFirst = [&BPP](std::pair<SDValue, unsigned> IterElt) {
12940-
return IterElt.first == *BPP.first.Src;
12966+
SmallVectorImpl<DotSrc> &Srcs = I == 0 ? Src0s : Src1s;
12967+
auto MatchesFirst = [&BPP](DotSrc &IterElt) {
12968+
return IterElt.SrcOp == *BPP.first.Src &&
12969+
(IterElt.DWordOffset == (BPP.first.SrcOffset / 4));
1294112970
};
1294212971

1294312972
auto Match = llvm::find_if(Srcs, MatchesFirst);
1294412973
if (Match != Srcs.end()) {
12945-
Match->second = addPermMasks(FirstMask, Match->second);
12974+
Match->PermMask = addPermMasks(FirstMask, Match->PermMask);
1294612975
FirstGroup = I;
1294712976
break;
1294812977
}
1294912978
}
1295012979
if (FirstGroup != -1) {
12951-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
12952-
FirstGroup == 1 ? Src0s : Src1s;
12953-
auto MatchesSecond = [&BPP](std::pair<SDValue, unsigned> IterElt) {
12954-
return IterElt.first == *BPP.second.Src;
12980+
SmallVectorImpl<DotSrc> &Srcs = FirstGroup == 1 ? Src0s : Src1s;
12981+
auto MatchesSecond = [&BPP](DotSrc &IterElt) {
12982+
return IterElt.SrcOp == *BPP.second.Src &&
12983+
(IterElt.DWordOffset == (BPP.second.SrcOffset / 4));
1295512984
};
1295612985
auto Match = llvm::find_if(Srcs, MatchesSecond);
1295712986
if (Match != Srcs.end()) {
12958-
Match->second = addPermMasks(SecondMask, Match->second);
12987+
Match->PermMask = addPermMasks(SecondMask, Match->PermMask);
1295912988
} else
12960-
Srcs.push_back({*BPP.second.Src, SecondMask});
12989+
Srcs.push_back({*BPP.second.Src, SecondMask, BPP.second.SrcOffset / 4});
1296112990
return;
1296212991
}
1296312992
}
@@ -12969,29 +12998,32 @@ static void placeSources(ByteProvider<SDValue> &Src0,
1296912998
unsigned FMask = 0xFF << (8 * (3 - Step));
1297012999

1297113000
Src0s.push_back(
12972-
{*Src0.Src, (Src0.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
13001+
{*Src0.Src,
13002+
((Src0.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
13003+
Src1.SrcOffset / 4});
1297313004
Src1s.push_back(
12974-
{*Src1.Src, (Src1.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
13005+
{*Src1.Src,
13006+
((Src1.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
13007+
Src1.SrcOffset / 4});
1297513008

1297613009
return;
1297713010
}
1297813011

12979-
static SDValue
12980-
resolveSources(SelectionDAG &DAG, SDLoc SL,
12981-
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
12982-
bool IsSigned, bool IsAny) {
13012+
static SDValue resolveSources(SelectionDAG &DAG, SDLoc SL,
13013+
SmallVectorImpl<DotSrc> &Srcs, bool IsSigned,
13014+
bool IsAny) {
1298313015

1298413016
// If we just have one source, just permute it accordingly.
1298513017
if (Srcs.size() == 1) {
1298613018
auto Elt = Srcs.begin();
12987-
auto EltVal = DAG.getBitcastedAnyExtOrTrunc(Elt->first, SL, MVT::i32);
13019+
auto EltOp = getDWordFromOffset(DAG, SL, Elt->SrcOp, Elt->DWordOffset);
1298813020

12989-
// v_perm will produce the original value.
12990-
if (Elt->second == 0x3020100)
12991-
return EltVal;
13021+
// v_perm will produce the original value
13022+
if (Elt->PermMask == 0x3020100)
13023+
return EltOp;
1299213024

12993-
return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal,
12994-
DAG.getConstant(Elt->second, SL, MVT::i32));
13025+
return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp,
13026+
DAG.getConstant(Elt->PermMask, SL, MVT::i32));
1299513027
}
1299613028

1299713029
auto FirstElt = Srcs.begin();
@@ -13002,8 +13034,8 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
1300213034
// If we have multiple sources in the chain, combine them via perms (using
1300313035
// calculated perm mask) and Ors.
1300413036
while (true) {
13005-
auto FirstMask = FirstElt->second;
13006-
auto SecondMask = SecondElt->second;
13037+
auto FirstMask = FirstElt->PermMask;
13038+
auto SecondMask = SecondElt->PermMask;
1300713039

1300813040
unsigned FirstCs = FirstMask & 0x0c0c0c0c;
1300913041
unsigned FirstPlusFour = FirstMask | 0x04040404;
@@ -13013,9 +13045,9 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
1301313045

1301413046
auto PermMask = addPermMasks(FirstMask, SecondMask);
1301513047
auto FirstVal =
13016-
DAG.getBitcastedAnyExtOrTrunc(FirstElt->first, SL, MVT::i32);
13048+
getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
1301713049
auto SecondVal =
13018-
DAG.getBitcastedAnyExtOrTrunc(SecondElt->first, SL, MVT::i32);
13050+
getDWordFromOffset(DAG, SL, SecondElt->SrcOp, SecondElt->DWordOffset);
1301913051

1302013052
Perms.push_back(DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, FirstVal,
1302113053
SecondVal,
@@ -13029,12 +13061,12 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
1302913061
// If we only have a FirstElt, then just combine that into the cumulative
1303013062
// source node.
1303113063
if (SecondElt == Srcs.end()) {
13032-
auto EltVal =
13033-
DAG.getBitcastedAnyExtOrTrunc(FirstElt->first, SL, MVT::i32);
13064+
auto EltOp =
13065+
getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
1303413066

1303513067
Perms.push_back(
13036-
DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal,
13037-
DAG.getConstant(FirstElt->second, SL, MVT::i32)));
13068+
DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp,
13069+
DAG.getConstant(FirstElt->PermMask, SL, MVT::i32)));
1303813070
break;
1303913071
}
1304013072
}
@@ -13045,9 +13077,8 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
1304513077
: Perms[0];
1304613078
}
1304713079

13048-
static void fixMasks(SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
13049-
unsigned ChainLength) {
13050-
for (auto &[EntryVal, EntryMask] : Srcs) {
13080+
static void fixMasks(SmallVectorImpl<DotSrc> &Srcs, unsigned ChainLength) {
13081+
for (auto &[EntryVal, EntryMask, EntryOffset] : Srcs) {
1305113082
EntryMask = EntryMask >> ((4 - ChainLength) * 8);
1305213083
auto ZeroMask = ChainLength == 2 ? 0x0c0c0000 : 0x0c000000;
1305313084
EntryMask += ZeroMask;
@@ -13112,8 +13143,8 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1311213143
(Subtarget->hasDot1Insts() || Subtarget->hasDot8Insts())) {
1311313144
SDValue TempNode(N, 0);
1311413145
std::optional<bool> IsSigned;
13115-
SmallVector<std::pair<SDValue, unsigned>, 4> Src0s;
13116-
SmallVector<std::pair<SDValue, unsigned>, 4> Src1s;
13146+
SmallVector<DotSrc, 4> Src0s;
13147+
SmallVector<DotSrc, 4> Src1s;
1311713148
SmallVector<SDValue, 4> Src2s;
1311813149

1311913150
// Match the v_dot4 tree, while collecting src nodes.
@@ -13191,11 +13222,11 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1319113222
// (commutation).
1319213223
bool UseOriginalSrc = false;
1319313224
if (ChainLength == 4 && Src0s.size() == 1 && Src1s.size() == 1 &&
13194-
Src0s.begin()->second == Src1s.begin()->second &&
13195-
Src0s.begin()->first.getValueSizeInBits() == 32 &&
13196-
Src1s.begin()->first.getValueSizeInBits() == 32) {
13225+
Src0s.begin()->PermMask == Src1s.begin()->PermMask &&
13226+
Src0s.begin()->SrcOp.getValueSizeInBits() >= 32 &&
13227+
Src1s.begin()->SrcOp.getValueSizeInBits() >= 32) {
1319713228
SmallVector<unsigned, 4> SrcBytes;
13198-
auto Src0Mask = Src0s.begin()->second;
13229+
auto Src0Mask = Src0s.begin()->PermMask;
1319913230
SrcBytes.push_back(Src0Mask & 0xFF000000);
1320013231
bool UniqueEntries = true;
1320113232
for (auto I = 1; I < 4; I++) {
@@ -13210,11 +13241,19 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1321013241

1321113242
if (UniqueEntries) {
1321213243
UseOriginalSrc = true;
13213-
// Must be 32 bits to enter above conditional.
13214-
assert(Src0s.begin()->first.getValueSizeInBits() == 32);
13215-
assert(Src1s.begin()->first.getValueSizeInBits() == 32);
13216-
Src0 = DAG.getBitcast(MVT::getIntegerVT(32), Src0s.begin()->first);
13217-
Src1 = DAG.getBitcast(MVT::getIntegerVT(32), Src1s.begin()->first);
13244+
13245+
auto FirstElt = Src0s.begin();
13246+
auto FirstEltOp =
13247+
getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
13248+
13249+
auto SecondElt = Src1s.begin();
13250+
auto SecondEltOp = getDWordFromOffset(DAG, SL, SecondElt->SrcOp,
13251+
SecondElt->DWordOffset);
13252+
13253+
Src0 = DAG.getBitcastedAnyExtOrTrunc(FirstEltOp, SL,
13254+
MVT::getIntegerVT(32));
13255+
Src1 = DAG.getBitcastedAnyExtOrTrunc(SecondEltOp, SL,
13256+
MVT::getIntegerVT(32));
1321813257
}
1321913258
}
1322013259

0 commit comments

Comments
 (0)