@@ -29830,6 +29830,144 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
}
}

+ // Constant ISD::SRA/SRL/SHL can be performed efficiently on vXi8 vectors by
+ // using vXi16 vector operations.
+ if (ConstantAmt &&
+ (VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
+ (VT == MVT::v64i8 && Subtarget.hasBWI())) &&
+ !Subtarget.hasXOP()) {
+ int NumElts = VT.getVectorNumElements();
+ MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
+ // We can do this extra fast if each pair of i8 elements is shifted by the
+ // same amount by doing this SWAR style: use a shift to move the valid bits
+ // to the right position, mask out any bits which crossed from one element
+ // to the other.
+ APInt UndefElts;
+ SmallVector<APInt, 64> AmtBits;
+ // This optimized lowering is only valid if the elements in a pair can
+ // be treated identically.
+ bool SameShifts = true;
+ SmallVector<APInt, 32> AmtBits16(NumElts / 2);
+ APInt UndefElts16 = APInt::getZero(AmtBits16.size());
+ if (getTargetConstantBitsFromNode(Amt, /*EltSizeInBits=*/8, UndefElts,
+ AmtBits, /*AllowWholeUndefs=*/true,
+ /*AllowPartialUndefs=*/false)) {
+ // Collect information to construct the BUILD_VECTOR for the i16 version
+ // of the shift. Conceptually, this is equivalent to:
+ // 1. Making sure the shift amounts are the same for both the low i8 and
+ // high i8 corresponding to the i16 lane.
+ // 2. Extending that shift amount to i16 for a build vector operation.
+ //
+ // We want to handle undef shift amounts which requires a little more
+ // logic (e.g. if one is undef and the other is not, grab the other shift
+ // amount).
+ for (unsigned SrcI = 0, E = AmtBits.size(); SrcI != E; SrcI += 2) {
+ unsigned DstI = SrcI / 2;
+ // Both elements are undef? Make a note and keep going.
+ if (UndefElts[SrcI] && UndefElts[SrcI + 1]) {
+ AmtBits16[DstI] = APInt::getZero(16);
+ UndefElts16.setBit(DstI);
+ continue;
+ }
+ // Even element is undef? We will shift it by the same shift amount as
+ // the odd element.
+ if (UndefElts[SrcI]) {
+ AmtBits16[DstI] = AmtBits[SrcI + 1].zext(16);
+ continue;
+ }
+ // Odd element is undef? We will shift it by the same shift amount as
+ // the even element.
+ if (UndefElts[SrcI + 1]) {
+ AmtBits16[DstI] = AmtBits[SrcI].zext(16);
+ continue;
+ }
+ // Both elements are equal.
+ if (AmtBits[SrcI] == AmtBits[SrcI + 1]) {
+ AmtBits16[DstI] = AmtBits[SrcI].zext(16);
+ continue;
+ }
+ // One of the provisional i16 elements will not have the same shift
+ // amount. Let's bail.
+ SameShifts = false;
+ break;
+ }
+ }
+ // We are only dealing with identical pairs.
+ if (SameShifts) {
+ // Cast the operand to vXi16.
+ SDValue R16 = DAG.getBitcast(VT16, R);
+ // Create our new vector of shift amounts.
+ SDValue Amt16 = getConstVector(AmtBits16, UndefElts16, VT16, DAG, dl);
+ // Perform the actual shift.
+ unsigned LogicalOpc = Opc == ISD::SRA ? ISD::SRL : Opc;
+ SDValue ShiftedR = DAG.getNode(LogicalOpc, dl, VT16, R16, Amt16);
+ // Now we need to construct a mask which will "drop" bits that get
+ // shifted past the LSB/MSB. For a logical shift left, it will look
+ // like:
+ // MaskLowBits = (0xff << Amt16) & 0xff;
+ // MaskHighBits = MaskLowBits << 8;
+ // Mask = MaskLowBits | MaskHighBits;
+ //
+ // This masking ensures that bits cannot migrate from one i8 to
+ // another. The construction of this mask will be constant folded.
+ // The mask for a logical right shift is nearly identical, the only
+ // difference is that 0xff is shifted right instead of left.
+ SDValue Cst255 = DAG.getConstant(0xff, dl, MVT::i16);
+ SDValue Splat255 = DAG.getSplat(VT16, dl, Cst255);
+ // The mask for the low bits is most simply expressed as an 8-bit
+ // field of all ones which is shifted in the exact same way the data
+ // is shifted but masked with 0xff.
+ SDValue MaskLowBits = DAG.getNode(LogicalOpc, dl, VT16, Splat255, Amt16);
+ MaskLowBits = DAG.getNode(ISD::AND, dl, VT16, MaskLowBits, Splat255);
+ SDValue Cst8 = DAG.getConstant(8, dl, MVT::i16);
+ SDValue Splat8 = DAG.getSplat(VT16, dl, Cst8);
+ // The mask for the high bits is the same as the mask for the low bits but
+ // shifted up by 8.
+ SDValue MaskHighBits =
+ DAG.getNode(ISD::SHL, dl, VT16, MaskLowBits, Splat8);
+ SDValue Mask = DAG.getNode(ISD::OR, dl, VT16, MaskLowBits, MaskHighBits);
+ // Finally, we mask the shifted vector with the SWAR mask.
+ SDValue Masked = DAG.getNode(ISD::AND, dl, VT16, ShiftedR, Mask);
+ Masked = DAG.getBitcast(VT, Masked);
+ if (Opc != ISD::SRA) {
+ // Logical shifts are complete at this point.
+ return Masked;
+ }
+ // At this point, we have done a *logical* shift right. We now need to
+ // sign extend the result so that we get behavior equivalent to an
+ // arithmetic shift right. Post-shifting by Amt16, our i8 elements are
+ // `8-Amt16` bits wide.
+ //
+ // To convert our `8-Amt16` bit unsigned numbers to 8-bit signed numbers,
+ // we need to replicate the bit at position `7-Amt16` into the MSBs of
+ // each i8.
+ // We can use the following trick to accomplish this:
+ // SignBitMask = 1 << (7-Amt16)
+ // (Masked ^ SignBitMask) - SignBitMask
+ //
+ // When the sign bit is already clear, this will compute:
+ // Masked + SignBitMask - SignBitMask
+ //
+ // This is equal to Masked which is what we want: the sign bit was clear
+ // so sign extending should be a no-op.
+ //
+ // When the sign bit is set, this will compute:
+ // Masked - SignBitMask - SignBitMask
+ //
+ // This is equal to Masked - 2*SignBitMask which will correctly sign
+ // extend our result.
+ SDValue CstHighBit = DAG.getConstant(0x80, dl, MVT::i8);
+ SDValue SplatHighBit = DAG.getSplat(VT, dl, CstHighBit);
+ // This does not induce recursion, all operands are constants.
+ SDValue SignBitMask = DAG.getNode(LogicalOpc, dl, VT, SplatHighBit, Amt);
+ SDValue FlippedSignBit =
+ DAG.getNode(ISD::XOR, dl, VT, Masked, SignBitMask);
+ SDValue Subtraction =
+ DAG.getNode(ISD::SUB, dl, VT, FlippedSignBit, SignBitMask);
+ return Subtraction;
+ }
+ }
+
// If possible, lower this packed shift into a vector multiply instead of
// expanding it into a sequence of scalar shifts.
// For v32i8 cases, it might be quicker to split/extend to vXi16 shifts.
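
The SWAR masking added above can be checked against a scalar model of a single
i16 lane holding two i8 elements that are shifted right logically by the same
amount. The snippet below is an illustrative sketch only, not part of the
patch, and the helper name is made up; it mirrors the mask construction the
comments describe: shift 0xff the same way as the data, clamp it back to 8
bits, and replicate it into the high byte.

#include <cstdint>
#include <cstdio>

// Hypothetical scalar model of one i16 lane: two i8 elements, same shift amount.
static uint16_t swarLogicalShiftRight(uint16_t Lane, unsigned Amt) {
  // Shift the whole 16-bit lane; bits from the high i8 leak into the low i8.
  uint16_t Shifted = (uint16_t)(Lane >> Amt);
  // Build the mask the same way the lowering does, with 0xff shifted right
  // because this models a logical right shift.
  uint16_t MaskLowBits = (uint16_t)((0xff >> Amt) & 0xff);
  uint16_t MaskHighBits = (uint16_t)(MaskLowBits << 8);
  uint16_t Mask = (uint16_t)(MaskLowBits | MaskHighBits);
  // Dropping the leaked bits gives the same result as two independent i8 shifts.
  return (uint16_t)(Shifted & Mask);
}

int main() {
  uint8_t Lo = 0x9c, Hi = 0xe7;
  unsigned Amt = 3;
  uint16_t Lane = (uint16_t)((Hi << 8) | Lo);
  uint16_t Got = swarLogicalShiftRight(Lane, Amt);
  uint16_t Want = (uint16_t)(((Hi >> Amt) << 8) | (Lo >> Amt));
  printf("got=0x%04x want=0x%04x\n", (unsigned)Got, (unsigned)Want); // both are 0x1c13
  return 0;
}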
@@ -29950,103 +30088,18 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
DAG.getNode(Opc, dl, ExtVT, R, Amt));
}

- // Constant ISD::SRA/SRL can be performed efficiently on vXi8 vectors by using
- // vXi16 vector operations.
+ // Constant ISD::SRA/SRL can be performed efficiently on vXi8 vectors as we
+ // extend to vXi16 to perform a MUL scale effectively as a MUL_LOHI.
if (ConstantAmt && (Opc == ISD::SRA || Opc == ISD::SRL) &&
(VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
(VT == MVT::v64i8 && Subtarget.hasBWI())) &&
!Subtarget.hasXOP()) {
int NumElts = VT.getVectorNumElements();
MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
- // We can do this extra fast if each pair of i8 elements is shifted by the
- // same amount by doing this SWAR style: use a shift to move the valid bits
- // to the right position, mask out any bits which crossed from one element
- // to the other.
- if (Opc == ISD::SRL || Opc == ISD::SHL) {
- APInt UndefElts;
- SmallVector<APInt, 64> AmtBits;
- if (getTargetConstantBitsFromNode(Amt, /*EltSizeInBits=*/8, UndefElts,
- AmtBits, /*AllowWholeUndefs=*/true,
- /*AllowPartialUndefs=*/false)) {
- // This optimized lowering is only valid if the elements in a pair can
- // be treated identically.
- bool SameShifts = true;
- SmallVector<APInt, 32> AmtBits16(NumElts / 2);
- APInt UndefElts16 = APInt::getZero(AmtBits16.size());
- for (unsigned SrcI = 0, E = AmtBits.size(); SrcI != E; SrcI += 2) {
- unsigned DstI = SrcI / 2;
- // Both elements are undef? Make a note and keep going.
- if (UndefElts[SrcI] && UndefElts[SrcI + 1]) {
- AmtBits16[DstI] = APInt::getZero(16);
- UndefElts16.setBit(DstI);
- continue;
- }
- // Even element is undef? We will shift it by the same shift amount as
- // the odd element.
- if (UndefElts[SrcI]) {
- AmtBits16[DstI] = AmtBits[SrcI + 1].zext(16);
- continue;
- }
- // Odd element is undef? We will shift it by the same shift amount as
- // the even element.
- if (UndefElts[SrcI + 1]) {
- AmtBits16[DstI] = AmtBits[SrcI].zext(16);
- continue;
- }
- // Both elements are equal.
- if (AmtBits[SrcI] == AmtBits[SrcI + 1]) {
- AmtBits16[DstI] = AmtBits[SrcI].zext(16);
- continue;
- }
- // One of the provisional i16 elements will not have the same shift
- // amount. Let's bail.
- SameShifts = false;
- break;
- }
-
- // We are only dealing with identical pairs and the operation is a
- // logical shift.
- if (SameShifts) {
- // Cast the operand to vXi16.
- SDValue R16 = DAG.getBitcast(VT16, R);
- // Create our new vector of shift amounts.
- SDValue Amt16 = getConstVector(AmtBits16, UndefElts16, VT16, DAG, dl);
- // Perform the actual shift.
- SDValue ShiftedR = DAG.getNode(Opc, dl, VT16, R16, Amt16);
- // Now we need to construct a mask which will "drop" bits that get
- // shifted past the LSB/MSB. For a logical shift left, it will look
- // like:
- // MaskLowBits = (0xff << Amt16) & 0xff;
- // MaskHighBits = MaskLowBits << 8;
- // Mask = MaskLowBits | MaskHighBits;
- //
- // This masking ensures that bits cannot migrate from one i8 to
- // another. The construction of this mask will be constant folded.
- // The mask for a logical right shift is nearly identical, the only
- // difference is that 0xff is shifted right instead of left.
- SDValue Cst255 = DAG.getConstant(0xff, dl, MVT::i16);
- SDValue Splat255 = DAG.getSplat(VT16, dl, Cst255);
- // The mask for the low bits is most simply expressed as an 8-bit
- // field of all ones which is shifted in the exact same way the data
- // is shifted but masked with 0xff.
- SDValue MaskLowBits = DAG.getNode(Opc, dl, VT16, Splat255, Amt16);
- MaskLowBits = DAG.getNode(ISD::AND, dl, VT16, MaskLowBits, Splat255);
- SDValue Cst8 = DAG.getConstant(8, dl, MVT::i16);
- SDValue Splat8 = DAG.getSplat(VT16, dl, Cst8);
- // Thie mask for the high bits is the same as the mask for the low
- // bits but shifted up by 8.
- SDValue MaskHighBits = DAG.getNode(ISD::SHL, dl, VT16, MaskLowBits, Splat8);
- SDValue Mask = DAG.getNode(ISD::OR, dl, VT16, MaskLowBits, MaskHighBits);
- // Finally, we mask the shifted vector with the SWAR mask.
- SDValue Masked = DAG.getNode(ISD::AND, dl, VT16, ShiftedR, Mask);
- return DAG.getBitcast(VT, Masked);
- }
- }
- }
SDValue Cst8 = DAG.getTargetConstant(8, dl, MVT::i8);

- // Extend to vXi16 to perform a MUL scale effectively as a MUL_LOHI (it
- // doesn't matter if the type isn't legal).
+ // Extend constant shift amount to vXi16 (it doesn't matter if the type
+ // isn't legal).
MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts);
Amt = DAG.getZExtOrTrunc(Amt, dl, ExVT);
Amt = DAG.getNode(ISD::SUB, dl, ExVT, DAG.getConstant(8, dl, ExVT), Amt);
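
For reference, the (Masked ^ SignBitMask) - SignBitMask fix-up added in the
first hunk, which turns the logical right shift back into an arithmetic one,
can be modelled on a single i8 element. Again an illustrative sketch with a
made-up helper name, not code from the patch:

#include <cstdint>
#include <cstdio>

// Hypothetical scalar model: arithmetic shift right of an i8 built from a
// logical shift plus the XOR/SUB sign-extension trick.
static uint8_t sraViaSrl(uint8_t V, unsigned Amt) {
  // After the logical shift the original sign bit sits at position 7 - Amt.
  uint8_t Srl = (uint8_t)(V >> Amt);
  uint8_t SignBitMask = (uint8_t)(0x80 >> Amt); // 1 << (7 - Amt)
  // Sign bit clear: the XOR adds the mask bit and the SUB removes it (a no-op).
  // Sign bit set: the XOR clears it and the SUB borrows, filling the bits above
  // position 7 - Amt with ones, i.e. sign extending.
  return (uint8_t)((Srl ^ SignBitMask) - SignBitMask);
}

int main() {
  int8_t V = -100; // 0x9c
  unsigned Amt = 3;
  int8_t Got = (int8_t)sraViaSrl((uint8_t)V, Amt);
  printf("%d vs %d\n", Got, V >> Amt); // both are -13
  return 0;
}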