Skip to content

Commit c82ebfb

Browse files
committed
Revert "[AMDGPU] Accept arbitrary sized sources in CalculateByteProvider"
This reverts commit ef33659. It was causing incorrect codegen for some Vulkan CTS tests.
1 parent 9d35387 commit c82ebfb

File tree

6 files changed

+114
-1499
lines changed

6 files changed

+114
-1499
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 86 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -10834,7 +10834,8 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
1083410834
if (Depth >= 6)
1083510835
return std::nullopt;
1083610836

10837-
if (Op.getValueSizeInBits() < 8)
10837+
auto ValueSize = Op.getValueSizeInBits();
10838+
if (ValueSize != 8 && ValueSize != 16 && ValueSize != 32)
1083810839
return std::nullopt;
1083910840

1084010841
switch (Op->getOpcode()) {
@@ -11125,6 +11126,8 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1112511126
auto VecIdx = IdxOp->getZExtValue();
1112611127
auto ScalarSize = Op.getScalarValueSizeInBits();
1112711128
if (ScalarSize != 32) {
11129+
if ((VecIdx + 1) * ScalarSize > 32)
11130+
return std::nullopt;
1112811131
Index = ScalarSize == 8 ? VecIdx : VecIdx * 2 + Index;
1112911132
}
1113011133

@@ -11210,6 +11213,9 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
1121011213
int Low16 = PermMask & 0xffff;
1121111214
int Hi16 = (PermMask & 0xffff0000) >> 16;
1121211215

11216+
assert(Op.getValueType().isByteSized());
11217+
assert(OtherOp.getValueType().isByteSized());
11218+
1121311219
auto TempOp = peekThroughBitcasts(Op);
1121411220
auto TempOtherOp = peekThroughBitcasts(OtherOp);
1121511221

@@ -11227,38 +11233,15 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
1122711233
return !addresses16Bits(Low16) || !addresses16Bits(Hi16);
1122811234
}
1122911235

11230-
static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
11231-
unsigned DWordOffset) {
11232-
SDValue Ret;
11233-
if (Src.getValueSizeInBits() <= 32)
11234-
return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
11235-
11236-
if (Src.getValueSizeInBits() >= 256) {
11237-
assert(!(Src.getValueSizeInBits() % 32));
11238-
Ret = DAG.getBitcast(
11239-
MVT::getVectorVT(MVT::i32, Src.getValueSizeInBits() / 32), Src);
11240-
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Ret,
11241-
DAG.getConstant(DWordOffset, SL, MVT::i32));
11242-
}
11243-
11244-
Ret = DAG.getBitcastedAnyExtOrTrunc(
11245-
Src, SL, MVT::getIntegerVT(Src.getValueSizeInBits()));
11246-
if (DWordOffset) {
11247-
auto Shifted = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11248-
DAG.getConstant(DWordOffset * 32, SL, MVT::i32));
11249-
return DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, Shifted);
11250-
}
11251-
11252-
return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11253-
}
11254-
1125511236
static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1125611237
SelectionDAG &DAG = DCI.DAG;
11257-
[[maybe_unused]] EVT VT = N->getValueType(0);
11258-
SmallVector<ByteProvider<SDValue>, 8> PermNodes;
11238+
EVT VT = N->getValueType(0);
11239+
11240+
if (VT != MVT::i32)
11241+
return SDValue();
1125911242

1126011243
// VT is known to be MVT::i32, so we need to provide 4 bytes.
11261-
assert(VT == MVT::i32);
11244+
SmallVector<ByteProvider<SDValue>, 8> PermNodes;
1126211245
for (int i = 0; i < 4; i++) {
1126311246
// Find the ByteProvider that provides the ith byte of the result of OR
1126411247
std::optional<ByteProvider<SDValue>> P =
@@ -11272,40 +11255,42 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1127211255
if (PermNodes.size() != 4)
1127311256
return SDValue();
1127411257

11275-
std::pair<unsigned, unsigned> FirstSrc(0, PermNodes[0].SrcOffset / 4);
11276-
std::optional<std::pair<unsigned, unsigned>> SecondSrc;
11258+
int FirstSrc = 0;
11259+
std::optional<int> SecondSrc;
1127711260
uint64_t PermMask = 0x00000000;
1127811261
for (size_t i = 0; i < PermNodes.size(); i++) {
1127911262
auto PermOp = PermNodes[i];
1128011263
// Since the mask is applied to Src1:Src2, Src1 bytes must be offset
1128111264
// by sizeof(Src2) = 4
1128211265
int SrcByteAdjust = 4;
1128311266

11284-
// If the Src uses a byte from a different DWORD, then it corresponds
11285-
// with a difference source
11286-
if (!PermOp.hasSameSrc(PermNodes[FirstSrc.first]) ||
11287-
((PermOp.SrcOffset / 4) != FirstSrc.second)) {
11288-
if (SecondSrc)
11289-
if (!PermOp.hasSameSrc(PermNodes[SecondSrc->first]) ||
11290-
((PermOp.SrcOffset / 4) != SecondSrc->second))
11267+
if (!PermOp.hasSameSrc(PermNodes[FirstSrc])) {
11268+
if (SecondSrc.has_value())
11269+
if (!PermOp.hasSameSrc(PermNodes[*SecondSrc]))
1129111270
return SDValue();
1129211271

1129311272
// Set the index of the second distinct Src node
11294-
SecondSrc = {i, PermNodes[i].SrcOffset / 4};
11295-
assert(!(PermNodes[SecondSrc->first].Src->getValueSizeInBits() % 8));
11273+
SecondSrc = i;
11274+
assert(!(PermNodes[*SecondSrc].Src->getValueSizeInBits() % 8));
1129611275
SrcByteAdjust = 0;
1129711276
}
11298-
assert((PermOp.SrcOffset % 4) + SrcByteAdjust < 8);
11277+
assert(PermOp.SrcOffset + SrcByteAdjust < 8);
1129911278
assert(!DAG.getDataLayout().isBigEndian());
11300-
PermMask |= ((PermOp.SrcOffset % 4) + SrcByteAdjust) << (i * 8);
11279+
PermMask |= (PermOp.SrcOffset + SrcByteAdjust) << (i * 8);
1130111280
}
11302-
SDLoc DL(N);
11303-
SDValue Op = *PermNodes[FirstSrc.first].Src;
11304-
Op = getDWordFromOffset(DAG, DL, Op, FirstSrc.second);
11305-
assert(Op.getValueSizeInBits() == 32);
11281+
11282+
SDValue Op = *PermNodes[FirstSrc].Src;
11283+
SDValue OtherOp = SecondSrc.has_value() ? *PermNodes[*SecondSrc].Src
11284+
: *PermNodes[FirstSrc].Src;
11285+
11286+
// Check that we haven't just recreated the same FSHR node.
11287+
if (N->getOpcode() == ISD::FSHR &&
11288+
(N->getOperand(0) == Op || N->getOperand(0) == OtherOp) &&
11289+
(N->getOperand(1) == Op || N->getOperand(1) == OtherOp))
11290+
return SDValue();
1130611291

1130711292
// Check that we are not just extracting the bytes in order from an op
11308-
if (!SecondSrc) {
11293+
if (Op == OtherOp && Op.getValueSizeInBits() == 32) {
1130911294
int Low16 = PermMask & 0xffff;
1131011295
int Hi16 = (PermMask & 0xffff0000) >> 16;
1131111296

@@ -11317,16 +11302,8 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1131711302
return DAG.getBitcast(MVT::getIntegerVT(32), Op);
1131811303
}
1131911304

11320-
SDValue OtherOp =
11321-
SecondSrc.has_value() ? *PermNodes[SecondSrc->first].Src : Op;
11322-
11323-
if (SecondSrc)
11324-
OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
11325-
11326-
assert(Op.getValueSizeInBits() == 32);
11327-
1132811305
if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
11329-
11306+
SDLoc DL(N);
1133011307
assert(Op.getValueType().isByteSized() &&
1133111308
OtherOp.getValueType().isByteSized());
1133211309

@@ -11341,6 +11318,7 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1134111318
return DAG.getNode(AMDGPUISD::PERM, DL, MVT::i32, Op, OtherOp,
1134211319
DAG.getConstant(PermMask, DL, MVT::i32));
1134311320
}
11321+
1134411322
return SDValue();
1134511323
}
1134611324

@@ -12816,24 +12794,17 @@ static unsigned addPermMasks(unsigned First, unsigned Second) {
1281612794
return (FirstNoCs | SecondNoCs) | (FirstCs & SecondCs);
1281712795
}
1281812796

12819-
struct DotSrc {
12820-
SDValue SrcOp;
12821-
int64_t PermMask;
12822-
int64_t DWordOffset;
12823-
};
12824-
1282512797
static void placeSources(ByteProvider<SDValue> &Src0,
1282612798
ByteProvider<SDValue> &Src1,
12827-
SmallVectorImpl<DotSrc> &Src0s,
12828-
SmallVectorImpl<DotSrc> &Src1s, int Step) {
12799+
SmallVectorImpl<std::pair<SDValue, unsigned>> &Src0s,
12800+
SmallVectorImpl<std::pair<SDValue, unsigned>> &Src1s,
12801+
int Step) {
1282912802

1283012803
assert(Src0.Src.has_value() && Src1.Src.has_value());
1283112804
// Src0s and Src1s are empty, just place arbitrarily.
1283212805
if (Step == 0) {
12833-
Src0s.push_back({*Src0.Src, ((Src0.SrcOffset % 4) << 24) + 0x0c0c0c,
12834-
Src0.SrcOffset / 4});
12835-
Src1s.push_back({*Src1.Src, ((Src1.SrcOffset % 4) << 24) + 0x0c0c0c,
12836-
Src1.SrcOffset / 4});
12806+
Src0s.push_back({*Src0.Src, (Src0.SrcOffset << 24) + 0x0c0c0c});
12807+
Src1s.push_back({*Src1.Src, (Src1.SrcOffset << 24) + 0x0c0c0c});
1283712808
return;
1283812809
}
1283912810

@@ -12846,38 +12817,38 @@ static void placeSources(ByteProvider<SDValue> &Src0,
1284612817
unsigned FMask = 0xFF << (8 * (3 - Step));
1284712818

1284812819
unsigned FirstMask =
12849-
(BPP.first.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12820+
BPP.first.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
1285012821
unsigned SecondMask =
12851-
(BPP.second.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12822+
BPP.second.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
1285212823
// Attempt to find Src vector which contains our SDValue, if so, add our
1285312824
// perm mask to the existing one. If we are unable to find a match for the
1285412825
// first SDValue, attempt to find match for the second.
1285512826
int FirstGroup = -1;
1285612827
for (int I = 0; I < 2; I++) {
12857-
SmallVectorImpl<DotSrc> &Srcs = I == 0 ? Src0s : Src1s;
12858-
auto MatchesFirst = [&BPP](DotSrc &IterElt) {
12859-
return IterElt.SrcOp == *BPP.first.Src &&
12860-
(IterElt.DWordOffset == (BPP.first.SrcOffset / 4));
12828+
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
12829+
I == 0 ? Src0s : Src1s;
12830+
auto MatchesFirst = [&BPP](std::pair<SDValue, unsigned> IterElt) {
12831+
return IterElt.first == *BPP.first.Src;
1286112832
};
1286212833

1286312834
auto Match = llvm::find_if(Srcs, MatchesFirst);
1286412835
if (Match != Srcs.end()) {
12865-
Match->PermMask = addPermMasks(FirstMask, Match->PermMask);
12836+
Match->second = addPermMasks(FirstMask, Match->second);
1286612837
FirstGroup = I;
1286712838
break;
1286812839
}
1286912840
}
1287012841
if (FirstGroup != -1) {
12871-
SmallVectorImpl<DotSrc> &Srcs = FirstGroup == 1 ? Src0s : Src1s;
12872-
auto MatchesSecond = [&BPP](DotSrc &IterElt) {
12873-
return IterElt.SrcOp == *BPP.second.Src &&
12874-
(IterElt.DWordOffset == (BPP.second.SrcOffset / 4));
12842+
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
12843+
FirstGroup == 1 ? Src0s : Src1s;
12844+
auto MatchesSecond = [&BPP](std::pair<SDValue, unsigned> IterElt) {
12845+
return IterElt.first == *BPP.second.Src;
1287512846
};
1287612847
auto Match = llvm::find_if(Srcs, MatchesSecond);
1287712848
if (Match != Srcs.end()) {
12878-
Match->PermMask = addPermMasks(SecondMask, Match->PermMask);
12849+
Match->second = addPermMasks(SecondMask, Match->second);
1287912850
} else
12880-
Srcs.push_back({*BPP.second.Src, SecondMask, BPP.second.SrcOffset / 4});
12851+
Srcs.push_back({*BPP.second.Src, SecondMask});
1288112852
return;
1288212853
}
1288312854
}
@@ -12889,32 +12860,29 @@ static void placeSources(ByteProvider<SDValue> &Src0,
1288912860
unsigned FMask = 0xFF << (8 * (3 - Step));
1289012861

1289112862
Src0s.push_back(
12892-
{*Src0.Src,
12893-
((Src0.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
12894-
Src1.SrcOffset / 4});
12863+
{*Src0.Src, (Src0.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
1289512864
Src1s.push_back(
12896-
{*Src1.Src,
12897-
((Src1.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
12898-
Src1.SrcOffset / 4});
12865+
{*Src1.Src, (Src1.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
1289912866

1290012867
return;
1290112868
}
1290212869

12903-
static SDValue resolveSources(SelectionDAG &DAG, SDLoc SL,
12904-
SmallVectorImpl<DotSrc> &Srcs, bool IsSigned,
12905-
bool IsAny) {
12870+
static SDValue
12871+
resolveSources(SelectionDAG &DAG, SDLoc SL,
12872+
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
12873+
bool IsSigned, bool IsAny) {
1290612874

1290712875
// If we just have one source, just permute it accordingly.
1290812876
if (Srcs.size() == 1) {
1290912877
auto Elt = Srcs.begin();
12910-
auto EltOp = getDWordFromOffset(DAG, SL, Elt->SrcOp, Elt->DWordOffset);
12878+
auto EltVal = DAG.getBitcastedAnyExtOrTrunc(Elt->first, SL, MVT::i32);
1291112879

12912-
// v_perm will produce the original value
12913-
if (Elt->PermMask == 0x3020100)
12914-
return EltOp;
12880+
// v_perm will produce the original value.
12881+
if (Elt->second == 0x3020100)
12882+
return EltVal;
1291512883

12916-
return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp,
12917-
DAG.getConstant(Elt->PermMask, SL, MVT::i32));
12884+
return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal,
12885+
DAG.getConstant(Elt->second, SL, MVT::i32));
1291812886
}
1291912887

1292012888
auto FirstElt = Srcs.begin();
@@ -12925,8 +12893,8 @@ static SDValue resolveSources(SelectionDAG &DAG, SDLoc SL,
1292512893
// If we have multiple sources in the chain, combine them via perms (using
1292612894
// calculated perm mask) and Ors.
1292712895
while (true) {
12928-
auto FirstMask = FirstElt->PermMask;
12929-
auto SecondMask = SecondElt->PermMask;
12896+
auto FirstMask = FirstElt->second;
12897+
auto SecondMask = SecondElt->second;
1293012898

1293112899
unsigned FirstCs = FirstMask & 0x0c0c0c0c;
1293212900
unsigned FirstPlusFour = FirstMask | 0x04040404;
@@ -12936,9 +12904,9 @@ static SDValue resolveSources(SelectionDAG &DAG, SDLoc SL,
1293612904

1293712905
auto PermMask = addPermMasks(FirstMask, SecondMask);
1293812906
auto FirstVal =
12939-
getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
12907+
DAG.getBitcastedAnyExtOrTrunc(FirstElt->first, SL, MVT::i32);
1294012908
auto SecondVal =
12941-
getDWordFromOffset(DAG, SL, SecondElt->SrcOp, SecondElt->DWordOffset);
12909+
DAG.getBitcastedAnyExtOrTrunc(SecondElt->first, SL, MVT::i32);
1294212910

1294312911
Perms.push_back(DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, FirstVal,
1294412912
SecondVal,
@@ -12952,12 +12920,12 @@ static SDValue resolveSources(SelectionDAG &DAG, SDLoc SL,
1295212920
// If we only have a FirstElt, then just combine that into the cumulative
1295312921
// source node.
1295412922
if (SecondElt == Srcs.end()) {
12955-
auto EltOp =
12956-
getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
12923+
auto EltVal =
12924+
DAG.getBitcastedAnyExtOrTrunc(FirstElt->first, SL, MVT::i32);
1295712925

1295812926
Perms.push_back(
12959-
DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp,
12960-
DAG.getConstant(FirstElt->PermMask, SL, MVT::i32)));
12927+
DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal,
12928+
DAG.getConstant(FirstElt->second, SL, MVT::i32)));
1296112929
break;
1296212930
}
1296312931
}
@@ -12968,8 +12936,9 @@ static SDValue resolveSources(SelectionDAG &DAG, SDLoc SL,
1296812936
: Perms[0];
1296912937
}
1297012938

12971-
static void fixMasks(SmallVectorImpl<DotSrc> &Srcs, unsigned ChainLength) {
12972-
for (auto &[EntryVal, EntryMask, EntryOffset] : Srcs) {
12939+
static void fixMasks(SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
12940+
unsigned ChainLength) {
12941+
for (auto &[EntryVal, EntryMask] : Srcs) {
1297312942
EntryMask = EntryMask >> ((4 - ChainLength) * 8);
1297412943
auto ZeroMask = ChainLength == 2 ? 0x0c0c0000 : 0x0c000000;
1297512944
EntryMask += ZeroMask;
@@ -13034,8 +13003,8 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1303413003
(Subtarget->hasDot1Insts() || Subtarget->hasDot8Insts())) {
1303513004
SDValue TempNode(N, 0);
1303613005
std::optional<bool> IsSigned;
13037-
SmallVector<DotSrc, 4> Src0s;
13038-
SmallVector<DotSrc, 4> Src1s;
13006+
SmallVector<std::pair<SDValue, unsigned>, 4> Src0s;
13007+
SmallVector<std::pair<SDValue, unsigned>, 4> Src1s;
1303913008
SmallVector<SDValue, 4> Src2s;
1304013009

1304113010
// Match the v_dot4 tree, while collecting src nodes.
@@ -13113,11 +13082,11 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1311313082
// (commutation).
1311413083
bool UseOriginalSrc = false;
1311513084
if (ChainLength == 4 && Src0s.size() == 1 && Src1s.size() == 1 &&
13116-
Src0s.begin()->PermMask == Src1s.begin()->PermMask &&
13117-
Src0s.begin()->SrcOp.getValueSizeInBits() >= 32 &&
13118-
Src1s.begin()->SrcOp.getValueSizeInBits() >= 32) {
13085+
Src0s.begin()->second == Src1s.begin()->second &&
13086+
Src0s.begin()->first.getValueSizeInBits() == 32 &&
13087+
Src1s.begin()->first.getValueSizeInBits() == 32) {
1311913088
SmallVector<unsigned, 4> SrcBytes;
13120-
auto Src0Mask = Src0s.begin()->PermMask;
13089+
auto Src0Mask = Src0s.begin()->second;
1312113090
SrcBytes.push_back(Src0Mask & 0xFF000000);
1312213091
bool UniqueEntries = true;
1312313092
for (auto I = 1; I < 4; I++) {
@@ -13132,19 +13101,11 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1313213101

1313313102
if (UniqueEntries) {
1313413103
UseOriginalSrc = true;
13135-
13136-
auto FirstElt = Src0s.begin();
13137-
auto FirstEltOp =
13138-
getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
13139-
13140-
auto SecondElt = Src1s.begin();
13141-
auto SecondEltOp = getDWordFromOffset(DAG, SL, SecondElt->SrcOp,
13142-
SecondElt->DWordOffset);
13143-
13144-
Src0 = DAG.getBitcastedAnyExtOrTrunc(FirstEltOp, SL,
13145-
MVT::getIntegerVT(32));
13146-
Src1 = DAG.getBitcastedAnyExtOrTrunc(SecondEltOp, SL,
13147-
MVT::getIntegerVT(32));
13104+
// Must be 32 bits to enter above conditional.
13105+
assert(Src0s.begin()->first.getValueSizeInBits() == 32);
13106+
assert(Src1s.begin()->first.getValueSizeInBits() == 32);
13107+
Src0 = DAG.getBitcast(MVT::getIntegerVT(32), Src0s.begin()->first);
13108+
Src1 = DAG.getBitcast(MVT::getIntegerVT(32), Src1s.begin()->first);
1314813109
}
1314913110
}
1315013111

0 commit comments

Comments
 (0)