@@ -11567,8 +11567,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
11567
11567
if (Depth >= 6)
11568
11568
return std::nullopt;
11569
11569
11570
- auto ValueSize = Op.getValueSizeInBits();
11571
- if (ValueSize != 8 && ValueSize != 16 && ValueSize != 32)
11570
+ if (Op.getValueSizeInBits() < 8)
11572
11571
return std::nullopt;
11573
11572
11574
11573
switch (Op->getOpcode()) {
@@ -11827,8 +11826,6 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11827
11826
auto VecIdx = IdxOp->getZExtValue();
11828
11827
auto ScalarSize = Op.getScalarValueSizeInBits();
11829
11828
if (ScalarSize != 32) {
11830
- if ((VecIdx + 1) * ScalarSize > 32)
11831
- return std::nullopt;
11832
11829
Index = ScalarSize == 8 ? VecIdx : VecIdx * 2 + Index;
11833
11830
}
11834
11831
@@ -11913,9 +11910,6 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
11913
11910
int Low16 = PermMask & 0xffff;
11914
11911
int Hi16 = (PermMask & 0xffff0000) >> 16;
11915
11912
11916
- assert(Op.getValueType().isByteSized());
11917
- assert(OtherOp.getValueType().isByteSized());
11918
-
11919
11913
auto TempOp = peekThroughBitcasts(Op);
11920
11914
auto TempOtherOp = peekThroughBitcasts(OtherOp);
11921
11915
@@ -11933,15 +11927,38 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
11933
11927
return !addresses16Bits(Low16) || !addresses16Bits(Hi16);
11934
11928
}
11935
11929
11930
+ static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
11931
+ unsigned DWordOffset) {
11932
+ SDValue Ret;
11933
+ if (Src.getValueSizeInBits() <= 32)
11934
+ return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
11935
+
11936
+ if (Src.getValueSizeInBits() >= 256) {
11937
+ assert(!(Src.getValueSizeInBits() % 32));
11938
+ Ret = DAG.getBitcast(
11939
+ MVT::getVectorVT(MVT::i32, Src.getValueSizeInBits() / 32), Src);
11940
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Ret,
11941
+ DAG.getConstant(DWordOffset, SL, MVT::i32));
11942
+ }
11943
+
11944
+ Ret = DAG.getBitcastedAnyExtOrTrunc(
11945
+ Src, SL, MVT::getIntegerVT(Src.getValueSizeInBits()));
11946
+ if (DWordOffset) {
11947
+ auto Shifted = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11948
+ DAG.getConstant(DWordOffset * 32, SL, MVT::i32));
11949
+ return DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, Shifted);
11950
+ }
11951
+
11952
+ return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11953
+ }
11954
+
11936
11955
static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
11937
11956
SelectionDAG &DAG = DCI.DAG;
11938
11957
EVT VT = N->getValueType(0);
11939
-
11940
- if (VT != MVT::i32)
11941
- return SDValue();
11958
+ SmallVector<ByteProvider<SDValue>, 8> PermNodes;
11942
11959
11943
11960
// VT is known to be MVT::i32, so we need to provide 4 bytes.
11944
- SmallVector<ByteProvider<SDValue>, 8> PermNodes ;
11961
+ assert(VT == MVT::i32) ;
11945
11962
for (int i = 0; i < 4; i++) {
11946
11963
// Find the ByteProvider that provides the ith byte of the result of OR
11947
11964
std::optional<ByteProvider<SDValue>> P =
@@ -11955,42 +11972,40 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
11955
11972
if (PermNodes.size() != 4)
11956
11973
return SDValue();
11957
11974
11958
- int FirstSrc = 0 ;
11959
- std::optional<int > SecondSrc;
11975
+ std::pair<unsigned, unsigned> FirstSrc(0, PermNodes[0].SrcOffset / 4) ;
11976
+ std::optional<std::pair<unsigned, unsigned> > SecondSrc;
11960
11977
uint64_t PermMask = 0x00000000;
11961
11978
for (size_t i = 0; i < PermNodes.size(); i++) {
11962
11979
auto PermOp = PermNodes[i];
11963
11980
// Since the mask is applied to Src1:Src2, Src1 bytes must be offset
11964
11981
// by sizeof(Src2) = 4
11965
11982
int SrcByteAdjust = 4;
11966
11983
11967
- if (!PermOp.hasSameSrc(PermNodes[FirstSrc])) {
11968
- if (SecondSrc.has_value())
11969
- if (!PermOp.hasSameSrc(PermNodes[*SecondSrc]))
11984
+ // If the Src uses a byte from a different DWORD, then it corresponds
11985
+ // with a different source
11986
+ if (!PermOp.hasSameSrc(PermNodes[FirstSrc.first]) ||
11987
+ ((PermOp.SrcOffset / 4) != FirstSrc.second)) {
11988
+ if (SecondSrc)
11989
+ if (!PermOp.hasSameSrc(PermNodes[SecondSrc->first]) ||
11990
+ ((PermOp.SrcOffset / 4) != SecondSrc->second))
11970
11991
return SDValue();
11971
11992
11972
11993
// Set the index of the second distinct Src node
11973
- SecondSrc = i ;
11974
- assert(!(PermNodes[* SecondSrc].Src->getValueSizeInBits() % 8));
11994
+ SecondSrc = {i, PermNodes[i].SrcOffset / 4} ;
11995
+ assert(!(PermNodes[SecondSrc->first ].Src->getValueSizeInBits() % 8));
11975
11996
SrcByteAdjust = 0;
11976
11997
}
11977
- assert(PermOp.SrcOffset + SrcByteAdjust < 8);
11998
+ assert(( PermOp.SrcOffset % 4) + SrcByteAdjust < 8);
11978
11999
assert(!DAG.getDataLayout().isBigEndian());
11979
- PermMask |= (PermOp.SrcOffset + SrcByteAdjust) << (i * 8);
12000
+ PermMask |= (( PermOp.SrcOffset % 4) + SrcByteAdjust) << (i * 8);
11980
12001
}
11981
-
11982
- SDValue Op = *PermNodes[FirstSrc].Src;
11983
- SDValue OtherOp = SecondSrc.has_value() ? *PermNodes[*SecondSrc].Src
11984
- : *PermNodes[FirstSrc].Src;
11985
-
11986
- // Check that we haven't just recreated the same FSHR node.
11987
- if (N->getOpcode() == ISD::FSHR &&
11988
- (N->getOperand(0) == Op || N->getOperand(0) == OtherOp) &&
11989
- (N->getOperand(1) == Op || N->getOperand(1) == OtherOp))
11990
- return SDValue();
12002
+ SDLoc DL(N);
12003
+ SDValue Op = *PermNodes[FirstSrc.first].Src;
12004
+ Op = getDWordFromOffset(DAG, DL, Op, FirstSrc.second);
12005
+ assert(Op.getValueSizeInBits() == 32);
11991
12006
11992
12007
// Check that we are not just extracting the bytes in order from an op
11993
- if (Op == OtherOp && Op.getValueSizeInBits() == 32 ) {
12008
+ if (!SecondSrc ) {
11994
12009
int Low16 = PermMask & 0xffff;
11995
12010
int Hi16 = (PermMask & 0xffff0000) >> 16;
11996
12011
@@ -12002,8 +12017,16 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
12002
12017
return DAG.getBitcast(MVT::getIntegerVT(32), Op);
12003
12018
}
12004
12019
12020
+ SDValue OtherOp =
12021
+ SecondSrc.has_value() ? *PermNodes[SecondSrc->first].Src : Op;
12022
+
12023
+ if (SecondSrc)
12024
+ OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
12025
+
12026
+ assert(Op.getValueSizeInBits() == 32);
12027
+
12005
12028
if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
12006
- SDLoc DL(N);
12029
+
12007
12030
assert(Op.getValueType().isByteSized() &&
12008
12031
OtherOp.getValueType().isByteSized());
12009
12032
@@ -12018,7 +12041,6 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
12018
12041
return DAG.getNode(AMDGPUISD::PERM, DL, MVT::i32, Op, OtherOp,
12019
12042
DAG.getConstant(PermMask, DL, MVT::i32));
12020
12043
}
12021
-
12022
12044
return SDValue();
12023
12045
}
12024
12046
@@ -13530,17 +13552,24 @@ static unsigned addPermMasks(unsigned First, unsigned Second) {
13530
13552
return (FirstNoCs | SecondNoCs) | (FirstCs & SecondCs);
13531
13553
}
13532
13554
13555
+ struct DotSrc {
13556
+ SDValue SrcOp;
13557
+ int64_t PermMask;
13558
+ int64_t DWordOffset;
13559
+ };
13560
+
13533
13561
static void placeSources(ByteProvider<SDValue> &Src0,
13534
13562
ByteProvider<SDValue> &Src1,
13535
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Src0s,
13536
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Src1s,
13537
- int Step) {
13563
+ SmallVectorImpl<DotSrc> &Src0s,
13564
+ SmallVectorImpl<DotSrc> &Src1s, int Step) {
13538
13565
13539
13566
assert(Src0.Src.has_value() && Src1.Src.has_value());
13540
13567
// Src0s and Src1s are empty, just place arbitrarily.
13541
13568
if (Step == 0) {
13542
- Src0s.push_back({*Src0.Src, (Src0.SrcOffset << 24) + 0x0c0c0c});
13543
- Src1s.push_back({*Src1.Src, (Src1.SrcOffset << 24) + 0x0c0c0c});
13569
+ Src0s.push_back({*Src0.Src, ((Src0.SrcOffset % 4) << 24) + 0x0c0c0c,
13570
+ Src0.SrcOffset / 4});
13571
+ Src1s.push_back({*Src1.Src, ((Src1.SrcOffset % 4) << 24) + 0x0c0c0c,
13572
+ Src1.SrcOffset / 4});
13544
13573
return;
13545
13574
}
13546
13575
@@ -13553,38 +13582,38 @@ static void placeSources(ByteProvider<SDValue> &Src0,
13553
13582
unsigned FMask = 0xFF << (8 * (3 - Step));
13554
13583
13555
13584
unsigned FirstMask =
13556
- BPP.first.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
13585
+ ( BPP.first.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
13557
13586
unsigned SecondMask =
13558
- BPP.second.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
13587
+ ( BPP.second.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
13559
13588
// Attempt to find Src vector which contains our SDValue, if so, add our
13560
13589
// perm mask to the existing one. If we are unable to find a match for the
13561
13590
// first SDValue, attempt to find match for the second.
13562
13591
int FirstGroup = -1;
13563
13592
for (int I = 0; I < 2; I++) {
13564
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
13565
- I == 0 ? Src0s : Src1s;
13566
- auto MatchesFirst = [& BPP](std::pair<SDValue, unsigned> IterElt) {
13567
- return IterElt.first == * BPP.first.Src ;
13593
+ SmallVectorImpl<DotSrc> &Srcs = I == 0 ? Src0s : Src1s;
13594
+ auto MatchesFirst = [&BPP](DotSrc &IterElt) {
13595
+ return IterElt.SrcOp == * BPP.first.Src &&
13596
+ ( IterElt.DWordOffset == ( BPP.first.SrcOffset / 4)) ;
13568
13597
};
13569
13598
13570
13599
auto Match = llvm::find_if(Srcs, MatchesFirst);
13571
13600
if (Match != Srcs.end()) {
13572
- Match->second = addPermMasks(FirstMask, Match->second );
13601
+ Match->PermMask = addPermMasks(FirstMask, Match->PermMask );
13573
13602
FirstGroup = I;
13574
13603
break;
13575
13604
}
13576
13605
}
13577
13606
if (FirstGroup != -1) {
13578
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
13579
- FirstGroup == 1 ? Src0s : Src1s;
13580
- auto MatchesSecond = [& BPP](std::pair<SDValue, unsigned> IterElt) {
13581
- return IterElt.first == * BPP.second.Src ;
13607
+ SmallVectorImpl<DotSrc> &Srcs = FirstGroup == 1 ? Src0s : Src1s;
13608
+ auto MatchesSecond = [&BPP](DotSrc &IterElt) {
13609
+ return IterElt.SrcOp == * BPP.second.Src &&
13610
+ ( IterElt.DWordOffset == ( BPP.second.SrcOffset / 4)) ;
13582
13611
};
13583
13612
auto Match = llvm::find_if(Srcs, MatchesSecond);
13584
13613
if (Match != Srcs.end()) {
13585
- Match->second = addPermMasks(SecondMask, Match->second );
13614
+ Match->PermMask = addPermMasks(SecondMask, Match->PermMask );
13586
13615
} else
13587
- Srcs.push_back({*BPP.second.Src, SecondMask});
13616
+ Srcs.push_back({*BPP.second.Src, SecondMask, BPP.second.SrcOffset / 4 });
13588
13617
return;
13589
13618
}
13590
13619
}
@@ -13596,29 +13625,32 @@ static void placeSources(ByteProvider<SDValue> &Src0,
13596
13625
unsigned FMask = 0xFF << (8 * (3 - Step));
13597
13626
13598
13627
Src0s.push_back(
13599
- {*Src0.Src, (Src0.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
13628
+ {*Src0.Src,
13629
+ ((Src0.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
13630
+ Src1.SrcOffset / 4});
13600
13631
Src1s.push_back(
13601
- {*Src1.Src, (Src1.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
13632
+ {*Src1.Src,
13633
+ ((Src1.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
13634
+ Src1.SrcOffset / 4});
13602
13635
13603
13636
return;
13604
13637
}
13605
13638
13606
- static SDValue
13607
- resolveSources(SelectionDAG &DAG, SDLoc SL,
13608
- SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
13609
- bool IsSigned, bool IsAny) {
13639
+ static SDValue resolveSources(SelectionDAG &DAG, SDLoc SL,
13640
+ SmallVectorImpl<DotSrc> &Srcs, bool IsSigned,
13641
+ bool IsAny) {
13610
13642
13611
13643
// If we just have one source, just permute it accordingly.
13612
13644
if (Srcs.size() == 1) {
13613
13645
auto Elt = Srcs.begin();
13614
- auto EltVal = DAG.getBitcastedAnyExtOrTrunc(Elt->first , SL, MVT::i32 );
13646
+ auto EltOp = getDWordFromOffset(DAG , SL, Elt->SrcOp, Elt->DWordOffset );
13615
13647
13616
- // v_perm will produce the original value.
13617
- if (Elt->second == 0x3020100)
13618
- return EltVal ;
13648
+ // v_perm will produce the original value
13649
+ if (Elt->PermMask == 0x3020100)
13650
+ return EltOp ;
13619
13651
13620
- return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal ,
13621
- DAG.getConstant(Elt->second , SL, MVT::i32));
13652
+ return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp ,
13653
+ DAG.getConstant(Elt->PermMask , SL, MVT::i32));
13622
13654
}
13623
13655
13624
13656
auto FirstElt = Srcs.begin();
@@ -13629,8 +13661,8 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
13629
13661
// If we have multiple sources in the chain, combine them via perms (using
13630
13662
// calculated perm mask) and Ors.
13631
13663
while (true) {
13632
- auto FirstMask = FirstElt->second ;
13633
- auto SecondMask = SecondElt->second ;
13664
+ auto FirstMask = FirstElt->PermMask ;
13665
+ auto SecondMask = SecondElt->PermMask ;
13634
13666
13635
13667
unsigned FirstCs = FirstMask & 0x0c0c0c0c;
13636
13668
unsigned FirstPlusFour = FirstMask | 0x04040404;
@@ -13640,9 +13672,9 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
13640
13672
13641
13673
auto PermMask = addPermMasks(FirstMask, SecondMask);
13642
13674
auto FirstVal =
13643
- DAG.getBitcastedAnyExtOrTrunc(FirstElt->first , SL, MVT::i32 );
13675
+ getDWordFromOffset(DAG , SL, FirstElt->SrcOp, FirstElt->DWordOffset );
13644
13676
auto SecondVal =
13645
- DAG.getBitcastedAnyExtOrTrunc(SecondElt->first , SL, MVT::i32 );
13677
+ getDWordFromOffset(DAG , SL, SecondElt->SrcOp, SecondElt->DWordOffset );
13646
13678
13647
13679
Perms.push_back(DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, FirstVal,
13648
13680
SecondVal,
@@ -13656,12 +13688,12 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
13656
13688
// If we only have a FirstElt, then just combine that into the cumulative
13657
13689
// source node.
13658
13690
if (SecondElt == Srcs.end()) {
13659
- auto EltVal =
13660
- DAG.getBitcastedAnyExtOrTrunc(FirstElt->first , SL, MVT::i32 );
13691
+ auto EltOp =
13692
+ getDWordFromOffset(DAG , SL, FirstElt->SrcOp, FirstElt->DWordOffset );
13661
13693
13662
13694
Perms.push_back(
13663
- DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal ,
13664
- DAG.getConstant(FirstElt->second , SL, MVT::i32)));
13695
+ DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp ,
13696
+ DAG.getConstant(FirstElt->PermMask , SL, MVT::i32)));
13665
13697
break;
13666
13698
}
13667
13699
}
@@ -13672,9 +13704,8 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
13672
13704
: Perms[0];
13673
13705
}
13674
13706
13675
- static void fixMasks(SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
13676
- unsigned ChainLength) {
13677
- for (auto &[EntryVal, EntryMask] : Srcs) {
13707
+ static void fixMasks(SmallVectorImpl<DotSrc> &Srcs, unsigned ChainLength) {
13708
+ for (auto &[EntryVal, EntryMask, EntryOffset] : Srcs) {
13678
13709
EntryMask = EntryMask >> ((4 - ChainLength) * 8);
13679
13710
auto ZeroMask = ChainLength == 2 ? 0x0c0c0000 : 0x0c000000;
13680
13711
EntryMask += ZeroMask;
@@ -13774,8 +13805,8 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
13774
13805
(Subtarget->hasDot1Insts() || Subtarget->hasDot8Insts())) {
13775
13806
SDValue TempNode(N, 0);
13776
13807
std::optional<bool> IsSigned;
13777
- SmallVector<std::pair<SDValue, unsigned> , 4> Src0s;
13778
- SmallVector<std::pair<SDValue, unsigned> , 4> Src1s;
13808
+ SmallVector<DotSrc , 4> Src0s;
13809
+ SmallVector<DotSrc , 4> Src1s;
13779
13810
SmallVector<SDValue, 4> Src2s;
13780
13811
13781
13812
// Match the v_dot4 tree, while collecting src nodes.
@@ -13857,11 +13888,11 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
13857
13888
// (commutation).
13858
13889
bool UseOriginalSrc = false;
13859
13890
if (ChainLength == 4 && Src0s.size() == 1 && Src1s.size() == 1 &&
13860
- Src0s.begin()->second == Src1s.begin()->second &&
13861
- Src0s.begin()->first .getValueSizeInBits() = = 32 &&
13862
- Src1s.begin()->first .getValueSizeInBits() = = 32) {
13891
+ Src0s.begin()->PermMask == Src1s.begin()->PermMask &&
13892
+ Src0s.begin()->SrcOp .getValueSizeInBits() > = 32 &&
13893
+ Src1s.begin()->SrcOp .getValueSizeInBits() > = 32) {
13863
13894
SmallVector<unsigned, 4> SrcBytes;
13864
- auto Src0Mask = Src0s.begin()->second ;
13895
+ auto Src0Mask = Src0s.begin()->PermMask ;
13865
13896
SrcBytes.push_back(Src0Mask & 0xFF000000);
13866
13897
bool UniqueEntries = true;
13867
13898
for (auto I = 1; I < 4; I++) {
@@ -13876,11 +13907,19 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
13876
13907
13877
13908
if (UniqueEntries) {
13878
13909
UseOriginalSrc = true;
13879
- // Must be 32 bits to enter above conditional.
13880
- assert(Src0s.begin()->first.getValueSizeInBits() == 32);
13881
- assert(Src1s.begin()->first.getValueSizeInBits() == 32);
13882
- Src0 = DAG.getBitcast(MVT::getIntegerVT(32), Src0s.begin()->first);
13883
- Src1 = DAG.getBitcast(MVT::getIntegerVT(32), Src1s.begin()->first);
13910
+
13911
+ auto FirstElt = Src0s.begin();
13912
+ auto FirstEltOp =
13913
+ getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
13914
+
13915
+ auto SecondElt = Src1s.begin();
13916
+ auto SecondEltOp = getDWordFromOffset(DAG, SL, SecondElt->SrcOp,
13917
+ SecondElt->DWordOffset);
13918
+
13919
+ Src0 = DAG.getBitcastedAnyExtOrTrunc(FirstEltOp, SL,
13920
+ MVT::getIntegerVT(32));
13921
+ Src1 = DAG.getBitcastedAnyExtOrTrunc(SecondEltOp, SL,
13922
+ MVT::getIntegerVT(32));
13884
13923
}
13885
13924
}
13886
13925
0 commit comments