@@ -11570,6 +11570,9 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
11570
11570
if (Op.getValueSizeInBits() < 8)
11571
11571
return std::nullopt;
11572
11572
11573
+ if (Op.getValueType().isVector())
11574
+ return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex);
11575
+
11573
11576
switch (Op->getOpcode()) {
11574
11577
case ISD::TRUNCATE: {
11575
11578
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
@@ -11636,8 +11639,12 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11636
11639
if (Index > BitWidth / 8 - 1)
11637
11640
return std::nullopt;
11638
11641
11642
+ bool IsVec = Op.getValueType().isVector();
11639
11643
switch (Op.getOpcode()) {
11640
11644
case ISD::OR: {
11645
+ if (IsVec)
11646
+ return std::nullopt;
11647
+
11641
11648
auto RHS = calculateByteProvider(Op.getOperand(1), Index, Depth + 1,
11642
11649
StartingIndex);
11643
11650
if (!RHS)
@@ -11658,6 +11665,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11658
11665
}
11659
11666
11660
11667
case ISD::AND: {
11668
+ if (IsVec)
11669
+ return std::nullopt;
11670
+
11661
11671
auto BitMaskOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
11662
11672
if (!BitMaskOp)
11663
11673
return std::nullopt;
@@ -11678,6 +11688,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11678
11688
}
11679
11689
11680
11690
case ISD::FSHR: {
11691
+ if (IsVec)
11692
+ return std::nullopt;
11693
+
11681
11694
// fshr(X,Y,Z): (X << (BW - (Z % BW))) | (Y >> (Z % BW))
11682
11695
auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(2));
11683
11696
if (!ShiftOp || Op.getValueType().isVector())
@@ -11703,6 +11716,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11703
11716
11704
11717
case ISD::SRA:
11705
11718
case ISD::SRL: {
11719
+ if (IsVec)
11720
+ return std::nullopt;
11721
+
11706
11722
auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
11707
11723
if (!ShiftOp)
11708
11724
return std::nullopt;
@@ -11728,6 +11744,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11728
11744
}
11729
11745
11730
11746
case ISD::SHL: {
11747
+ if (IsVec)
11748
+ return std::nullopt;
11749
+
11731
11750
auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
11732
11751
if (!ShiftOp)
11733
11752
return std::nullopt;
@@ -11752,6 +11771,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11752
11771
case ISD::SIGN_EXTEND_INREG:
11753
11772
case ISD::AssertZext:
11754
11773
case ISD::AssertSext: {
11774
+ if (IsVec)
11775
+ return std::nullopt;
11776
+
11755
11777
SDValue NarrowOp = Op->getOperand(0);
11756
11778
unsigned NarrowBitWidth = NarrowOp.getValueSizeInBits();
11757
11779
if (Op->getOpcode() == ISD::SIGN_EXTEND_INREG ||
@@ -11773,6 +11795,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11773
11795
}
11774
11796
11775
11797
case ISD::TRUNCATE: {
11798
+ if (IsVec)
11799
+ return std::nullopt;
11800
+
11776
11801
uint64_t NarrowByteWidth = BitWidth / 8;
11777
11802
11778
11803
if (NarrowByteWidth >= Index) {
@@ -11815,9 +11840,13 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11815
11840
return std::nullopt;
11816
11841
}
11817
11842
11818
- case ISD::BSWAP:
11843
+ case ISD::BSWAP: {
11844
+ if (IsVec)
11845
+ return std::nullopt;
11846
+
11819
11847
return calculateByteProvider(Op->getOperand(0), BitWidth / 8 - Index - 1,
11820
11848
Depth + 1, StartingIndex);
11849
+ }
11821
11850
11822
11851
case ISD::EXTRACT_VECTOR_ELT: {
11823
11852
auto IdxOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
@@ -11834,6 +11863,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
11834
11863
}
11835
11864
11836
11865
case AMDGPUISD::PERM: {
11866
+ if (IsVec)
11867
+ return std::nullopt;
11868
+
11837
11869
auto PermMask = dyn_cast<ConstantSDNode>(Op->getOperand(2));
11838
11870
if (!PermMask)
11839
11871
return std::nullopt;
@@ -11930,25 +11962,55 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
11930
11962
static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
11931
11963
unsigned DWordOffset) {
11932
11964
SDValue Ret;
11933
- if (Src.getValueSizeInBits() <= 32)
11934
- return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
11935
11965
11936
- if (Src.getValueSizeInBits() >= 256) {
11937
- assert(!(Src.getValueSizeInBits() % 32));
11938
- Ret = DAG.getBitcast(
11939
- MVT::getVectorVT(MVT::i32, Src.getValueSizeInBits() / 32), Src);
11940
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Ret,
11941
- DAG.getConstant(DWordOffset, SL, MVT::i32));
11942
- }
11966
+ auto TypeSize = Src.getValueSizeInBits().getFixedValue();
11967
+ // ByteProvider must be at least 8 bits
11968
+ assert(Src.getValueSizeInBits().isKnownMultipleOf(8));
11943
11969
11944
- Ret = DAG.getBitcastedAnyExtOrTrunc(
11945
- Src, SL, MVT::getIntegerVT(Src.getValueSizeInBits()));
11946
- if (DWordOffset) {
11947
- auto Shifted = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11948
- DAG.getConstant(DWordOffset * 32, SL, MVT::i32));
11949
- return DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, Shifted);
11950
- }
11970
+ if (TypeSize <= 32)
11971
+ return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
11951
11972
11973
+ if (Src.getValueType().isVector()) {
11974
+ auto ScalarTySize = Src.getScalarValueSizeInBits();
11975
+ auto ScalarTy = Src.getValueType().getScalarType();
11976
+ if (ScalarTySize == 32) {
11977
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Src,
11978
+ DAG.getConstant(DWordOffset, SL, MVT::i32));
11979
+ }
11980
+ if (ScalarTySize > 32) {
11981
+ Ret = DAG.getNode(
11982
+ ISD::EXTRACT_VECTOR_ELT, SL, ScalarTy, Src,
11983
+ DAG.getConstant(DWordOffset / (ScalarTySize / 32), SL, MVT::i32));
11984
+ auto ShiftVal = 32 * (DWordOffset % (ScalarTySize / 32));
11985
+ if (ShiftVal)
11986
+ Ret = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11987
+ DAG.getConstant(ShiftVal, SL, MVT::i32));
11988
+ return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11989
+ }
11990
+
11991
+ assert(ScalarTySize < 32);
11992
+ auto NumElements = TypeSize / ScalarTySize;
11993
+ auto Trunc32Elements = (ScalarTySize * NumElements) / 32;
11994
+ auto NormalizedTrunc = Trunc32Elements * 32 / ScalarTySize;
11995
+ auto NumElementsIn32 = 32 / ScalarTySize;
11996
+ auto NumAvailElements = DWordOffset < Trunc32Elements
11997
+ ? NumElementsIn32
11998
+ : NumElements - NormalizedTrunc;
11999
+
12000
+ SmallVector<SDValue, 4> VecSrcs;
12001
+ DAG.ExtractVectorElements(Src, VecSrcs, DWordOffset * NumElementsIn32,
12002
+ NumAvailElements);
12003
+
12004
+ Ret = DAG.getBuildVector(
12005
+ MVT::getVectorVT(MVT::getIntegerVT(ScalarTySize), NumAvailElements), SL,
12006
+ VecSrcs);
12007
+ return Ret = DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
12008
+ }
12009
+
12010
+ /// Scalar Type
12011
+ auto ShiftVal = 32 * DWordOffset;
12012
+ Ret = DAG.getNode(ISD::SRL, SL, Src.getValueType(), Src,
12013
+ DAG.getConstant(ShiftVal, SL, MVT::i32));
11952
12014
return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11953
12015
}
11954
12016
@@ -12017,13 +12079,12 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
12017
12079
return DAG.getBitcast(MVT::getIntegerVT(32), Op);
12018
12080
}
12019
12081
12020
- SDValue OtherOp =
12021
- SecondSrc.has_value() ? *PermNodes[SecondSrc->first].Src : Op;
12082
+ SDValue OtherOp = SecondSrc ? *PermNodes[SecondSrc->first].Src : Op;
12022
12083
12023
- if (SecondSrc)
12084
+ if (SecondSrc) {
12024
12085
OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
12025
-
12026
- assert(Op.getValueSizeInBits() == 32);
12086
+ assert(OtherOp.getValueSizeInBits() == 32);
12087
+ }
12027
12088
12028
12089
if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
12029
12090
0 commit comments