@@ -10943,8 +10943,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
10943
10943
if (Depth >= 6)
10944
10944
return std::nullopt;
10945
10945
10946
- auto ValueSize = Op.getValueSizeInBits();
10947
- if (ValueSize != 8 && ValueSize != 16 && ValueSize != 32)
10946
+ if (Op.getValueSizeInBits() < 8)
10948
10947
return std::nullopt;
10949
10948
10950
10949
switch (Op->getOpcode()) {
@@ -11235,8 +11234,6 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11235
11234
auto VecIdx = IdxOp->getZExtValue();
11236
11235
auto ScalarSize = Op.getScalarValueSizeInBits();
11237
11236
if (ScalarSize != 32) {
11238
- if ((VecIdx + 1) * ScalarSize > 32)
11239
- return std::nullopt;
11240
11237
Index = ScalarSize == 8 ? VecIdx : VecIdx * 2 + Index;
11241
11238
}
11242
11239
@@ -11322,9 +11319,6 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
11322
11319
int Low16 = PermMask & 0xffff;
11323
11320
int Hi16 = (PermMask & 0xffff0000) >> 16;
11324
11321
11325
- assert(Op.getValueType().isByteSized());
11326
- assert(OtherOp.getValueType().isByteSized());
11327
-
11328
11322
auto TempOp = peekThroughBitcasts(Op);
11329
11323
auto TempOtherOp = peekThroughBitcasts(OtherOp);
11330
11324
@@ -11342,15 +11336,38 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
11342
11336
return !addresses16Bits(Low16) || !addresses16Bits(Hi16);
11343
11337
}
11344
11338
11339
+ static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
11340
+ unsigned DWordOffset) {
11341
+ SDValue Ret;
11342
+ if (Src.getValueSizeInBits() <= 32)
11343
+ return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
11344
+
11345
+ if (Src.getValueSizeInBits() >= 256) {
11346
+ assert(!(Src.getValueSizeInBits() % 32));
11347
+ Ret = DAG.getBitcast(
11348
+ MVT::getVectorVT(MVT::i32, Src.getValueSizeInBits() / 32), Src);
11349
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Ret,
11350
+ DAG.getConstant(DWordOffset, SL, MVT::i32));
11351
+ }
11352
+
11353
+ Ret = DAG.getBitcastedAnyExtOrTrunc(
11354
+ Src, SL, MVT::getIntegerVT(Src.getValueSizeInBits()));
11355
+ if (DWordOffset) {
11356
+ auto Shifted = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11357
+ DAG.getConstant(DWordOffset * 32, SL, MVT::i32));
11358
+ return DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, Shifted);
11359
+ }
11360
+
11361
+ return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11362
+ }
11363
+
11345
11364
static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
11346
11365
SelectionDAG &DAG = DCI.DAG;
11347
11366
EVT VT = N->getValueType(0);
11348
-
11349
- if (VT != MVT::i32)
11350
- return SDValue();
11367
+ SmallVector<ByteProvider<SDValue>, 8> PermNodes;
11351
11368
11352
11369
// VT is known to be MVT::i32, so we need to provide 4 bytes.
11353
- SmallVector<ByteProvider<SDValue>, 8> PermNodes ;
11370
+ assert(VT == MVT::i32) ;
11354
11371
for (int i = 0; i < 4; i++) {
11355
11372
// Find the ByteProvider that provides the ith byte of the result of OR
11356
11373
std::optional<ByteProvider<SDValue>> P =
@@ -11364,42 +11381,40 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
11364
11381
if (PermNodes.size() != 4)
11365
11382
return SDValue();
11366
11383
11367
- int FirstSrc = 0 ;
11368
- std::optional<int > SecondSrc;
11384
+ std::pair<unsigned, unsigned> FirstSrc(0, PermNodes[0].SrcOffset / 4) ;
11385
+ std::optional<std::pair<unsigned, unsigned> > SecondSrc;
11369
11386
uint64_t PermMask = 0x00000000;
11370
11387
for (size_t i = 0; i < PermNodes.size(); i++) {
11371
11388
auto PermOp = PermNodes[i];
11372
11389
// Since the mask is applied to Src1:Src2, Src1 bytes must be offset
11373
11390
// by sizeof(Src2) = 4
11374
11391
int SrcByteAdjust = 4;
11375
11392
11376
- if (!PermOp.hasSameSrc(PermNodes[FirstSrc])) {
11377
- if (SecondSrc.has_value())
11378
- if (!PermOp.hasSameSrc(PermNodes[*SecondSrc]))
11393
+ // If the Src uses a byte from a different DWORD, then it corresponds
11394
+ // with a difference source
11395
+ if (!PermOp.hasSameSrc(PermNodes[FirstSrc.first]) ||
11396
+ ((PermOp.SrcOffset / 4) != FirstSrc.second)) {
11397
+ if (SecondSrc)
11398
+ if (!PermOp.hasSameSrc(PermNodes[SecondSrc->first]) ||
11399
+ ((PermOp.SrcOffset / 4) != SecondSrc->second))
11379
11400
return SDValue();
11380
11401
11381
11402
// Set the index of the second distinct Src node
11382
- SecondSrc = i ;
11383
- assert(!(PermNodes[* SecondSrc].Src->getValueSizeInBits() % 8));
11403
+ SecondSrc = {i, PermNodes[i].SrcOffset / 4} ;
11404
+ assert(!(PermNodes[SecondSrc->first ].Src->getValueSizeInBits() % 8));
11384
11405
SrcByteAdjust = 0;
11385
11406
}
11386
- assert(PermOp.SrcOffset + SrcByteAdjust < 8);
11407
+ assert(( PermOp.SrcOffset % 4) + SrcByteAdjust < 8);
11387
11408
assert(!DAG.getDataLayout().isBigEndian());
11388
- PermMask |= (PermOp.SrcOffset + SrcByteAdjust) << (i * 8);
11409
+ PermMask |= (( PermOp.SrcOffset % 4) + SrcByteAdjust) << (i * 8);
11389
11410
}
11390
-
11391
- SDValue Op = *PermNodes[FirstSrc].Src;
11392
- SDValue OtherOp = SecondSrc.has_value() ? *PermNodes[*SecondSrc].Src
11393
- : *PermNodes[FirstSrc].Src;
11394
-
11395
- // Check that we haven't just recreated the same FSHR node.
11396
- if (N->getOpcode() == ISD::FSHR &&
11397
- (N->getOperand(0) == Op || N->getOperand(0) == OtherOp) &&
11398
- (N->getOperand(1) == Op || N->getOperand(1) == OtherOp))
11399
- return SDValue();
11411
+ SDLoc DL(N);
11412
+ SDValue Op = *PermNodes[FirstSrc.first].Src;
11413
+ Op = getDWordFromOffset(DAG, DL, Op, FirstSrc.second);
11414
+ assert(Op.getValueSizeInBits() == 32);
11400
11415
11401
11416
// Check that we are not just extracting the bytes in order from an op
11402
- if (Op == OtherOp && Op.getValueSizeInBits() == 32 ) {
11417
+ if (!SecondSrc ) {
11403
11418
int Low16 = PermMask & 0xffff;
11404
11419
int Hi16 = (PermMask & 0xffff0000) >> 16;
11405
11420
@@ -11411,8 +11426,16 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
11411
11426
return DAG.getBitcast(MVT::getIntegerVT(32), Op);
11412
11427
}
11413
11428
11429
+ SDValue OtherOp =
11430
+ SecondSrc.has_value() ? *PermNodes[SecondSrc->first].Src : Op;
11431
+
11432
+ if (SecondSrc)
11433
+ OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
11434
+
11435
+ assert(Op.getValueSizeInBits() == 32);
11436
+
11414
11437
if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
11415
- SDLoc DL(N);
11438
+
11416
11439
assert(Op.getValueType().isByteSized() &&
11417
11440
OtherOp.getValueType().isByteSized());
11418
11441
@@ -11427,7 +11450,6 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
11427
11450
return DAG.getNode(AMDGPUISD::PERM, DL, MVT::i32, Op, OtherOp,
11428
11451
DAG.getConstant(PermMask, DL, MVT::i32));
11429
11452
}
11430
-
11431
11453
return SDValue();
11432
11454
}
11433
11455
@@ -12903,17 +12925,24 @@ static unsigned addPermMasks(unsigned First, unsigned Second) {
12903
12925
return (FirstNoCs | SecondNoCs) | (FirstCs & SecondCs);
12904
12926
}
12905
12927
12928
+ struct DotSrc {
12929
+ SDValue SrcOp;
12930
+ int64_t PermMask;
12931
+ int64_t DWordOffset;
12932
+ };
12933
+
12906
12934
static void placeSources(ByteProvider<SDValue> &Src0,
12907
12935
ByteProvider<SDValue> &Src1,
12908
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Src0s,
12909
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Src1s,
12910
- int Step) {
12936
+ SmallVectorImpl<DotSrc> &Src0s,
12937
+ SmallVectorImpl<DotSrc> &Src1s, int Step) {
12911
12938
12912
12939
assert(Src0.Src.has_value() && Src1.Src.has_value());
12913
12940
// Src0s and Src1s are empty, just place arbitrarily.
12914
12941
if (Step == 0) {
12915
- Src0s.push_back({*Src0.Src, (Src0.SrcOffset << 24) + 0x0c0c0c});
12916
- Src1s.push_back({*Src1.Src, (Src1.SrcOffset << 24) + 0x0c0c0c});
12942
+ Src0s.push_back({*Src0.Src, ((Src0.SrcOffset % 4) << 24) + 0x0c0c0c,
12943
+ Src0.SrcOffset / 4});
12944
+ Src1s.push_back({*Src1.Src, ((Src1.SrcOffset % 4) << 24) + 0x0c0c0c,
12945
+ Src1.SrcOffset / 4});
12917
12946
return;
12918
12947
}
12919
12948
@@ -12926,38 +12955,38 @@ static void placeSources(ByteProvider<SDValue> &Src0,
12926
12955
unsigned FMask = 0xFF << (8 * (3 - Step));
12927
12956
12928
12957
unsigned FirstMask =
12929
- BPP.first.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12958
+ ( BPP.first.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12930
12959
unsigned SecondMask =
12931
- BPP.second.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12960
+ ( BPP.second.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12932
12961
// Attempt to find Src vector which contains our SDValue, if so, add our
12933
12962
// perm mask to the existing one. If we are unable to find a match for the
12934
12963
// first SDValue, attempt to find match for the second.
12935
12964
int FirstGroup = -1;
12936
12965
for (int I = 0; I < 2; I++) {
12937
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
12938
- I == 0 ? Src0s : Src1s;
12939
- auto MatchesFirst = [& BPP](std::pair<SDValue, unsigned> IterElt) {
12940
- return IterElt.first == * BPP.first.Src ;
12966
+ SmallVectorImpl<DotSrc> &Srcs = I == 0 ? Src0s : Src1s;
12967
+ auto MatchesFirst = [&BPP](DotSrc &IterElt) {
12968
+ return IterElt.SrcOp == * BPP.first.Src &&
12969
+ ( IterElt.DWordOffset == ( BPP.first.SrcOffset / 4)) ;
12941
12970
};
12942
12971
12943
12972
auto Match = llvm::find_if(Srcs, MatchesFirst);
12944
12973
if (Match != Srcs.end()) {
12945
- Match->second = addPermMasks(FirstMask, Match->second );
12974
+ Match->PermMask = addPermMasks(FirstMask, Match->PermMask );
12946
12975
FirstGroup = I;
12947
12976
break;
12948
12977
}
12949
12978
}
12950
12979
if (FirstGroup != -1) {
12951
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
12952
- FirstGroup == 1 ? Src0s : Src1s;
12953
- auto MatchesSecond = [& BPP](std::pair<SDValue, unsigned> IterElt) {
12954
- return IterElt.first == * BPP.second.Src ;
12980
+ SmallVectorImpl<DotSrc> &Srcs = FirstGroup == 1 ? Src0s : Src1s;
12981
+ auto MatchesSecond = [&BPP](DotSrc &IterElt) {
12982
+ return IterElt.SrcOp == * BPP.second.Src &&
12983
+ ( IterElt.DWordOffset == ( BPP.second.SrcOffset / 4)) ;
12955
12984
};
12956
12985
auto Match = llvm::find_if(Srcs, MatchesSecond);
12957
12986
if (Match != Srcs.end()) {
12958
- Match->second = addPermMasks(SecondMask, Match->second );
12987
+ Match->PermMask = addPermMasks(SecondMask, Match->PermMask );
12959
12988
} else
12960
- Srcs.push_back({*BPP.second.Src, SecondMask});
12989
+ Srcs.push_back({*BPP.second.Src, SecondMask, BPP.second.SrcOffset / 4 });
12961
12990
return;
12962
12991
}
12963
12992
}
@@ -12969,29 +12998,32 @@ static void placeSources(ByteProvider<SDValue> &Src0,
12969
12998
unsigned FMask = 0xFF << (8 * (3 - Step));
12970
12999
12971
13000
Src0s.push_back(
12972
- {*Src0.Src, (Src0.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
13001
+ {*Src0.Src,
13002
+ ((Src0.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
13003
+ Src1.SrcOffset / 4});
12973
13004
Src1s.push_back(
12974
- {*Src1.Src, (Src1.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
13005
+ {*Src1.Src,
13006
+ ((Src1.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
13007
+ Src1.SrcOffset / 4});
12975
13008
12976
13009
return;
12977
13010
}
12978
13011
12979
- static SDValue
12980
- resolveSources(SelectionDAG &DAG, SDLoc SL,
12981
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
12982
- bool IsSigned, bool IsAny) {
13012
+ static SDValue resolveSources(SelectionDAG &DAG, SDLoc SL,
13013
+ SmallVectorImpl<DotSrc> &Srcs, bool IsSigned,
13014
+ bool IsAny) {
12983
13015
12984
13016
// If we just have one source, just permute it accordingly.
12985
13017
if (Srcs.size() == 1) {
12986
13018
auto Elt = Srcs.begin();
12987
- auto EltVal = DAG.getBitcastedAnyExtOrTrunc(Elt->first , SL, MVT::i32 );
13019
+ auto EltOp = getDWordFromOffset(DAG , SL, Elt->SrcOp, Elt->DWordOffset );
12988
13020
12989
- // v_perm will produce the original value.
12990
- if (Elt->second == 0x3020100)
12991
- return EltVal ;
13021
+ // v_perm will produce the original value
13022
+ if (Elt->PermMask == 0x3020100)
13023
+ return EltOp ;
12992
13024
12993
- return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal ,
12994
- DAG.getConstant(Elt->second , SL, MVT::i32));
13025
+ return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp ,
13026
+ DAG.getConstant(Elt->PermMask , SL, MVT::i32));
12995
13027
}
12996
13028
12997
13029
auto FirstElt = Srcs.begin();
@@ -13002,8 +13034,8 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
13002
13034
// If we have multiple sources in the chain, combine them via perms (using
13003
13035
// calculated perm mask) and Ors.
13004
13036
while (true) {
13005
- auto FirstMask = FirstElt->second ;
13006
- auto SecondMask = SecondElt->second ;
13037
+ auto FirstMask = FirstElt->PermMask ;
13038
+ auto SecondMask = SecondElt->PermMask ;
13007
13039
13008
13040
unsigned FirstCs = FirstMask & 0x0c0c0c0c;
13009
13041
unsigned FirstPlusFour = FirstMask | 0x04040404;
@@ -13013,9 +13045,9 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
13013
13045
13014
13046
auto PermMask = addPermMasks(FirstMask, SecondMask);
13015
13047
auto FirstVal =
13016
- DAG.getBitcastedAnyExtOrTrunc(FirstElt->first , SL, MVT::i32 );
13048
+ getDWordFromOffset(DAG , SL, FirstElt->SrcOp, FirstElt->DWordOffset );
13017
13049
auto SecondVal =
13018
- DAG.getBitcastedAnyExtOrTrunc(SecondElt->first , SL, MVT::i32 );
13050
+ getDWordFromOffset(DAG , SL, SecondElt->SrcOp, SecondElt->DWordOffset );
13019
13051
13020
13052
Perms.push_back(DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, FirstVal,
13021
13053
SecondVal,
@@ -13029,12 +13061,12 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
13029
13061
// If we only have a FirstElt, then just combine that into the cumulative
13030
13062
// source node.
13031
13063
if (SecondElt == Srcs.end()) {
13032
- auto EltVal =
13033
- DAG.getBitcastedAnyExtOrTrunc(FirstElt->first , SL, MVT::i32 );
13064
+ auto EltOp =
13065
+ getDWordFromOffset(DAG , SL, FirstElt->SrcOp, FirstElt->DWordOffset );
13034
13066
13035
13067
Perms.push_back(
13036
- DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal ,
13037
- DAG.getConstant(FirstElt->second , SL, MVT::i32)));
13068
+ DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp ,
13069
+ DAG.getConstant(FirstElt->PermMask , SL, MVT::i32)));
13038
13070
break;
13039
13071
}
13040
13072
}
@@ -13045,9 +13077,8 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
13045
13077
: Perms[0];
13046
13078
}
13047
13079
13048
- static void fixMasks(SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
13049
- unsigned ChainLength) {
13050
- for (auto &[EntryVal, EntryMask] : Srcs) {
13080
+ static void fixMasks(SmallVectorImpl<DotSrc> &Srcs, unsigned ChainLength) {
13081
+ for (auto &[EntryVal, EntryMask, EntryOffset] : Srcs) {
13051
13082
EntryMask = EntryMask >> ((4 - ChainLength) * 8);
13052
13083
auto ZeroMask = ChainLength == 2 ? 0x0c0c0000 : 0x0c000000;
13053
13084
EntryMask += ZeroMask;
@@ -13112,8 +13143,8 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
13112
13143
(Subtarget->hasDot1Insts() || Subtarget->hasDot8Insts())) {
13113
13144
SDValue TempNode(N, 0);
13114
13145
std::optional<bool> IsSigned;
13115
- SmallVector<std::pair<SDValue, unsigned> , 4> Src0s;
13116
- SmallVector<std::pair<SDValue, unsigned> , 4> Src1s;
13146
+ SmallVector<DotSrc , 4> Src0s;
13147
+ SmallVector<DotSrc , 4> Src1s;
13117
13148
SmallVector<SDValue, 4> Src2s;
13118
13149
13119
13150
// Match the v_dot4 tree, while collecting src nodes.
@@ -13191,11 +13222,11 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
13191
13222
// (commutation).
13192
13223
bool UseOriginalSrc = false;
13193
13224
if (ChainLength == 4 && Src0s.size() == 1 && Src1s.size() == 1 &&
13194
- Src0s.begin()->second == Src1s.begin()->second &&
13195
- Src0s.begin()->first .getValueSizeInBits() = = 32 &&
13196
- Src1s.begin()->first .getValueSizeInBits() = = 32) {
13225
+ Src0s.begin()->PermMask == Src1s.begin()->PermMask &&
13226
+ Src0s.begin()->SrcOp .getValueSizeInBits() > = 32 &&
13227
+ Src1s.begin()->SrcOp .getValueSizeInBits() > = 32) {
13197
13228
SmallVector<unsigned, 4> SrcBytes;
13198
- auto Src0Mask = Src0s.begin()->second ;
13229
+ auto Src0Mask = Src0s.begin()->PermMask ;
13199
13230
SrcBytes.push_back(Src0Mask & 0xFF000000);
13200
13231
bool UniqueEntries = true;
13201
13232
for (auto I = 1; I < 4; I++) {
@@ -13210,11 +13241,19 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
13210
13241
13211
13242
if (UniqueEntries) {
13212
13243
UseOriginalSrc = true;
13213
- // Must be 32 bits to enter above conditional.
13214
- assert(Src0s.begin()->first.getValueSizeInBits() == 32);
13215
- assert(Src1s.begin()->first.getValueSizeInBits() == 32);
13216
- Src0 = DAG.getBitcast(MVT::getIntegerVT(32), Src0s.begin()->first);
13217
- Src1 = DAG.getBitcast(MVT::getIntegerVT(32), Src1s.begin()->first);
13244
+
13245
+ auto FirstElt = Src0s.begin();
13246
+ auto FirstEltOp =
13247
+ getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
13248
+
13249
+ auto SecondElt = Src1s.begin();
13250
+ auto SecondEltOp = getDWordFromOffset(DAG, SL, SecondElt->SrcOp,
13251
+ SecondElt->DWordOffset);
13252
+
13253
+ Src0 = DAG.getBitcastedAnyExtOrTrunc(FirstEltOp, SL,
13254
+ MVT::getIntegerVT(32));
13255
+ Src1 = DAG.getBitcastedAnyExtOrTrunc(SecondEltOp, SL,
13256
+ MVT::getIntegerVT(32));
13218
13257
}
13219
13258
}
13220
13259
0 commit comments