Commit 49c5ceb

[X86] Improve support for vXi8 arithmetic shifts, logical left shifts
Use SWAR techniques for arithmetic shifts: we use the same technique as for logical right shifts, but with an additional step of sign extending the result. Also, use the logical shift left technique even on AVX512, as vpmovzxbw and vpmovwb are actually quite expensive.
1 parent e99755d
9 files changed: +568 -111 lines changed
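
To make the SWAR idea in the commit message concrete, here is a small standalone C++ sketch. It is illustrative only: the helper names are invented for this sketch and it is not the LLVM lowering itself. Two i8 lanes are packed into a uint16_t, shifted with a single 16-bit logical shift, the bits that crossed the lane boundary are masked off, and for the arithmetic case each lane is then sign extended with the XOR/SUB trick. The result is checked against an ordinary per-element shift.

// swar_i8_shift_sketch.cpp -- illustrative only; names and structure are
// assumptions for this sketch, not the LLVM implementation.
#include <cassert>
#include <cstdint>
#include <cstdio>

// Logical right shift of both i8 lanes packed in a uint16_t by the same
// constant amount: one 16-bit shift, then a mask dropping any bits that
// crossed from the high lane into the low lane.
static uint16_t SwarLogicalShiftRight(uint16_t Pair, unsigned Amt) {
  uint16_t Shifted = (uint16_t)(Pair >> Amt);
  uint16_t MaskLowBits = (uint16_t)((0xFFu >> Amt) & 0xFFu);
  uint16_t Mask = (uint16_t)(MaskLowBits | (MaskLowBits << 8));
  return (uint16_t)(Shifted & Mask);
}

// Arithmetic right shift: logical shift first, then sign extend each lane
// with (x ^ SignBit) - SignBit, where SignBit = 0x80 >> Amt. The vector code
// does the XOR and SUB per i8 lane (vpxor + vpsubb); here the two lanes are
// handled one byte at a time for clarity.
static uint16_t SwarArithShiftRight(uint16_t Pair, unsigned Amt) {
  uint16_t Masked = SwarLogicalShiftRight(Pair, Amt);
  uint8_t SignBit = (uint8_t)(0x80u >> Amt);
  uint8_t Lo = (uint8_t)(((Masked & 0xFFu) ^ SignBit) - SignBit);
  uint8_t Hi = (uint8_t)(((Masked >> 8) ^ SignBit) - SignBit);
  return (uint16_t)(Lo | ((uint16_t)Hi << 8));
}

int main() {
  // Reference per-element arithmetic shift, written with the ~(~v >> a)
  // identity so it is well defined for negative values even before C++20.
  auto AShr = [](int V, unsigned A) { return V < 0 ? ~(~V >> A) : V >> A; };
  for (int Low = -128; Low < 128; ++Low)
    for (int High = -128; High < 128; ++High)
      for (unsigned Amt = 0; Amt < 8; ++Amt) {
        uint16_t Pair = (uint16_t)(((uint8_t)High << 8) | (uint8_t)Low);
        uint16_t Got = SwarArithShiftRight(Pair, Amt);
        uint8_t RefLo = (uint8_t)AShr((int8_t)Low, Amt);
        uint8_t RefHi = (uint8_t)AShr((int8_t)High, Amt);
        assert((Got & 0xFFu) == RefLo && (Got >> 8) == RefHi);
      }
  puts("SWAR i8-pair arithmetic shift matches the scalar reference");
}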

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 142 additions & 89 deletions
@@ -29830,6 +29830,144 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
     }
   }
 
+  // Constant ISD::SRA/SRL/SHL can be performed efficiently on vXi8 vectors by
+  // using vXi16 vector operations.
+  if (ConstantAmt &&
+      (VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
+       (VT == MVT::v64i8 && Subtarget.hasBWI())) &&
+      !Subtarget.hasXOP()) {
+    int NumElts = VT.getVectorNumElements();
+    MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
+    // We can do this extra fast if each pair of i8 elements is shifted by the
+    // same amount by doing this SWAR style: use a shift to move the valid bits
+    // to the right position, mask out any bits which crossed from one element
+    // to the other.
+    APInt UndefElts;
+    SmallVector<APInt, 64> AmtBits;
+    // This optimized lowering is only valid if the elements in a pair can
+    // be treated identically.
+    bool SameShifts = true;
+    SmallVector<APInt, 32> AmtBits16(NumElts / 2);
+    APInt UndefElts16 = APInt::getZero(AmtBits16.size());
+    if (getTargetConstantBitsFromNode(Amt, /*EltSizeInBits=*/8, UndefElts,
+                                      AmtBits, /*AllowWholeUndefs=*/true,
+                                      /*AllowPartialUndefs=*/false)) {
+      // Collect information to construct the BUILD_VECTOR for the i16 version
+      // of the shift. Conceptually, this is equivalent to:
+      // 1. Making sure the shift amounts are the same for both the low i8 and
+      //    high i8 corresponding to the i16 lane.
+      // 2. Extending that shift amount to i16 for a build vector operation.
+      //
+      // We want to handle undef shift amounts which requires a little more
+      // logic (e.g. if one is undef and the other is not, grab the other shift
+      // amount).
+      for (unsigned SrcI = 0, E = AmtBits.size(); SrcI != E; SrcI += 2) {
+        unsigned DstI = SrcI / 2;
+        // Both elements are undef? Make a note and keep going.
+        if (UndefElts[SrcI] && UndefElts[SrcI + 1]) {
+          AmtBits16[DstI] = APInt::getZero(16);
+          UndefElts16.setBit(DstI);
+          continue;
+        }
+        // Even element is undef? We will shift it by the same shift amount as
+        // the odd element.
+        if (UndefElts[SrcI]) {
+          AmtBits16[DstI] = AmtBits[SrcI + 1].zext(16);
+          continue;
+        }
+        // Odd element is undef? We will shift it by the same shift amount as
+        // the even element.
+        if (UndefElts[SrcI + 1]) {
+          AmtBits16[DstI] = AmtBits[SrcI].zext(16);
+          continue;
+        }
+        // Both elements are equal.
+        if (AmtBits[SrcI] == AmtBits[SrcI + 1]) {
+          AmtBits16[DstI] = AmtBits[SrcI].zext(16);
+          continue;
+        }
+        // One of the provisional i16 elements will not have the same shift
+        // amount. Let's bail.
+        SameShifts = false;
+        break;
+      }
+    }
+    // We are only dealing with identical pairs.
+    if (SameShifts) {
+      // Cast the operand to vXi16.
+      SDValue R16 = DAG.getBitcast(VT16, R);
+      // Create our new vector of shift amounts.
+      SDValue Amt16 = getConstVector(AmtBits16, UndefElts16, VT16, DAG, dl);
+      // Perform the actual shift.
+      unsigned LogicalOpc = Opc == ISD::SRA ? ISD::SRL : Opc;
+      SDValue ShiftedR = DAG.getNode(LogicalOpc, dl, VT16, R16, Amt16);
+      // Now we need to construct a mask which will "drop" bits that get
+      // shifted past the LSB/MSB. For a logical shift left, it will look
+      // like:
+      //   MaskLowBits = (0xff << Amt16) & 0xff;
+      //   MaskHighBits = MaskLowBits << 8;
+      //   Mask = MaskLowBits | MaskHighBits;
+      //
+      // This masking ensures that bits cannot migrate from one i8 to
+      // another. The construction of this mask will be constant folded.
+      // The mask for a logical right shift is nearly identical, the only
+      // difference is that 0xff is shifted right instead of left.
+      SDValue Cst255 = DAG.getConstant(0xff, dl, MVT::i16);
+      SDValue Splat255 = DAG.getSplat(VT16, dl, Cst255);
+      // The mask for the low bits is most simply expressed as an 8-bit
+      // field of all ones which is shifted in the exact same way the data
+      // is shifted but masked with 0xff.
+      SDValue MaskLowBits = DAG.getNode(LogicalOpc, dl, VT16, Splat255, Amt16);
+      MaskLowBits = DAG.getNode(ISD::AND, dl, VT16, MaskLowBits, Splat255);
+      SDValue Cst8 = DAG.getConstant(8, dl, MVT::i16);
+      SDValue Splat8 = DAG.getSplat(VT16, dl, Cst8);
+      // The mask for the high bits is the same as the mask for the low bits but
+      // shifted up by 8.
+      SDValue MaskHighBits =
+          DAG.getNode(ISD::SHL, dl, VT16, MaskLowBits, Splat8);
+      SDValue Mask = DAG.getNode(ISD::OR, dl, VT16, MaskLowBits, MaskHighBits);
+      // Finally, we mask the shifted vector with the SWAR mask.
+      SDValue Masked = DAG.getNode(ISD::AND, dl, VT16, ShiftedR, Mask);
+      Masked = DAG.getBitcast(VT, Masked);
+      if (Opc != ISD::SRA) {
+        // Logical shifts are complete at this point.
+        return Masked;
+      }
+      // At this point, we have done a *logical* shift right. We now need to
+      // sign extend the result so that we get behavior equivalent to an
+      // arithmetic shift right. Post-shifting by Amt16, our i8 elements are
+      // `8-Amt16` bits wide.
+      //
+      // To convert our `8-Amt16` bit unsigned numbers to 8-bit signed numbers,
+      // we need to replicate the bit at position `7-Amt16` into the MSBs of
+      // each i8.
+      // We can use the following trick to accomplish this:
+      //   SignBitMask = 1 << (7-Amt16)
+      //   (Masked ^ SignBitMask) - SignBitMask
+      //
+      // When the sign bit is already clear, this will compute:
+      //   Masked + SignBitMask - SignBitMask
+      //
+      // This is equal to Masked which is what we want: the sign bit was clear
+      // so sign extending should be a no-op.
+      //
+      // When the sign bit is set, this will compute:
+      //   Masked - SignBitmask - SignBitMask
+      //
+      // This is equal to Masked - 2*SignBitMask which will correctly sign
+      // extend our result.
+      SDValue CstHighBit = DAG.getConstant(0x80, dl, MVT::i8);
+      SDValue SplatHighBit = DAG.getSplat(VT, dl, CstHighBit);
+      // This does not induce recursion, all operands are constants.
+      SDValue SignBitMask = DAG.getNode(LogicalOpc, dl, VT, SplatHighBit, Amt);
+      SDValue FlippedSignBit =
+          DAG.getNode(ISD::XOR, dl, VT, Masked, SignBitMask);
+      SDValue Subtraction =
+          DAG.getNode(ISD::SUB, dl, VT, FlippedSignBit, SignBitMask);
+      return Subtraction;
+    }
+  }
+
   // If possible, lower this packed shift into a vector multiply instead of
   // expanding it into a sequence of scalar shifts.
   // For v32i8 cases, it might be quicker to split/extend to vXi16 shifts.
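
The mask recipe spelled out in the comments of the hunk above for the shift-left case can be checked in isolation. The sketch below is a standalone C++ illustration with made-up names, not the LLVM code: it applies one 16-bit SHL to a packed pair of bytes, masks with (0xff << Amt) & 0xff replicated into both lanes, and exhaustively verifies the result against two independent per-byte shifts.

// swar_shl_mask_sketch.cpp -- illustrative only; not the LLVM code.
#include <cassert>
#include <cstdint>

// Shift both i8 lanes of a uint16_t left by the same constant amount using
// one 16-bit shift, then drop the bits which crossed the lane boundary.
static uint16_t SwarShiftLeft(uint16_t Pair, unsigned Amt) {
  uint16_t Shifted = (uint16_t)(Pair << Amt);
  uint16_t MaskLowBits = (uint16_t)((0xFFu << Amt) & 0xFFu);
  uint16_t Mask = (uint16_t)(MaskLowBits | (MaskLowBits << 8));
  return (uint16_t)(Shifted & Mask);
}

int main() {
  // Exhaustively compare against two independent per-byte shifts.
  for (unsigned Lo = 0; Lo < 256; ++Lo)
    for (unsigned Hi = 0; Hi < 256; ++Hi)
      for (unsigned Amt = 0; Amt < 8; ++Amt) {
        uint16_t Got = SwarShiftLeft((uint16_t)((Hi << 8) | Lo), Amt);
        assert((Got & 0xFFu) == ((Lo << Amt) & 0xFFu));
        assert((Got >> 8) == ((Hi << Amt) & 0xFFu));
      }
  return 0;
}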
@@ -29950,103 +30088,18 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
                        DAG.getNode(Opc, dl, ExtVT, R, Amt));
   }
 
-  // Constant ISD::SRA/SRL can be performed efficiently on vXi8 vectors by using
-  // vXi16 vector operations.
+  // Constant ISD::SRA/SRL can be performed efficiently on vXi8 vectors as we
+  // extend to vXi16 to perform a MUL scale effectively as a MUL_LOHI.
   if (ConstantAmt && (Opc == ISD::SRA || Opc == ISD::SRL) &&
       (VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
        (VT == MVT::v64i8 && Subtarget.hasBWI())) &&
       !Subtarget.hasXOP()) {
     int NumElts = VT.getVectorNumElements();
     MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
-    // We can do this extra fast if each pair of i8 elements is shifted by the
-    // same amount by doing this SWAR style: use a shift to move the valid bits
-    // to the right position, mask out any bits which crossed from one element
-    // to the other.
-    if (Opc == ISD::SRL || Opc == ISD::SHL) {
-      APInt UndefElts;
-      SmallVector<APInt, 64> AmtBits;
-      if (getTargetConstantBitsFromNode(Amt, /*EltSizeInBits=*/8, UndefElts,
-                                        AmtBits, /*AllowWholeUndefs=*/true,
-                                        /*AllowPartialUndefs=*/false)) {
-        // This optimized lowering is only valid if the elements in a pair can
-        // be treated identically.
-        bool SameShifts = true;
-        SmallVector<APInt, 32> AmtBits16(NumElts / 2);
-        APInt UndefElts16 = APInt::getZero(AmtBits16.size());
-        for (unsigned SrcI = 0, E = AmtBits.size(); SrcI != E; SrcI += 2) {
-          unsigned DstI = SrcI / 2;
-          // Both elements are undef? Make a note and keep going.
-          if (UndefElts[SrcI] && UndefElts[SrcI + 1]) {
-            AmtBits16[DstI] = APInt::getZero(16);
-            UndefElts16.setBit(DstI);
-            continue;
-          }
-          // Even element is undef? We will shift it by the same shift amount as
-          // the odd element.
-          if (UndefElts[SrcI]) {
-            AmtBits16[DstI] = AmtBits[SrcI + 1].zext(16);
-            continue;
-          }
-          // Odd element is undef? We will shift it by the same shift amount as
-          // the even element.
-          if (UndefElts[SrcI + 1]) {
-            AmtBits16[DstI] = AmtBits[SrcI].zext(16);
-            continue;
-          }
-          // Both elements are equal.
-          if (AmtBits[SrcI] == AmtBits[SrcI + 1]) {
-            AmtBits16[DstI] = AmtBits[SrcI].zext(16);
-            continue;
-          }
-          // One of the provisional i16 elements will not have the same shift
-          // amount. Let's bail.
-          SameShifts = false;
-          break;
-        }
-
-        // We are only dealing with identical pairs and the operation is a
-        // logical shift.
-        if (SameShifts) {
-          // Cast the operand to vXi16.
-          SDValue R16 = DAG.getBitcast(VT16, R);
-          // Create our new vector of shift amounts.
-          SDValue Amt16 = getConstVector(AmtBits16, UndefElts16, VT16, DAG, dl);
-          // Perform the actual shift.
-          SDValue ShiftedR = DAG.getNode(Opc, dl, VT16, R16, Amt16);
-          // Now we need to construct a mask which will "drop" bits that get
-          // shifted past the LSB/MSB. For a logical shift left, it will look
-          // like:
-          //   MaskLowBits = (0xff << Amt16) & 0xff;
-          //   MaskHighBits = MaskLowBits << 8;
-          //   Mask = MaskLowBits | MaskHighBits;
-          //
-          // This masking ensures that bits cannot migrate from one i8 to
-          // another. The construction of this mask will be constant folded.
-          // The mask for a logical right shift is nearly identical, the only
-          // difference is that 0xff is shifted right instead of left.
-          SDValue Cst255 = DAG.getConstant(0xff, dl, MVT::i16);
-          SDValue Splat255 = DAG.getSplat(VT16, dl, Cst255);
-          // The mask for the low bits is most simply expressed as an 8-bit
-          // field of all ones which is shifted in the exact same way the data
-          // is shifted but masked with 0xff.
-          SDValue MaskLowBits = DAG.getNode(Opc, dl, VT16, Splat255, Amt16);
-          MaskLowBits = DAG.getNode(ISD::AND, dl, VT16, MaskLowBits, Splat255);
-          SDValue Cst8 = DAG.getConstant(8, dl, MVT::i16);
-          SDValue Splat8 = DAG.getSplat(VT16, dl, Cst8);
-          // Thie mask for the high bits is the same as the mask for the low
-          // bits but shifted up by 8.
-          SDValue MaskHighBits = DAG.getNode(ISD::SHL, dl, VT16, MaskLowBits, Splat8);
-          SDValue Mask = DAG.getNode(ISD::OR, dl, VT16, MaskLowBits, MaskHighBits);
-          // Finally, we mask the shifted vector with the SWAR mask.
-          SDValue Masked = DAG.getNode(ISD::AND, dl, VT16, ShiftedR, Mask);
-          return DAG.getBitcast(VT, Masked);
-        }
-      }
-    }
     SDValue Cst8 = DAG.getTargetConstant(8, dl, MVT::i8);
 
-    // Extend to vXi16 to perform a MUL scale effectively as a MUL_LOHI (it
-    // doesn't matter if the type isn't legal).
+    // Extend constant shift amount to vXi16 (it doesn't matter if the type
+    // isn't legal).
     MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts);
     Amt = DAG.getZExtOrTrunc(Amt, dl, ExVT);
     Amt = DAG.getNode(ISD::SUB, dl, ExVT, DAG.getConstant(8, dl, ExVT), Amt);
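
The retained path's comment describes performing the constant right shift by extending to vXi16 and using a multiply treated as a MUL_LOHI, and the test output below indeed uses vpmulhuw with power-of-two multipliers. As a sanity check of the underlying identity only (not of the exact lowering), the standalone snippet below verifies that for a 16-bit unsigned x and 1 <= a <= 15, x >> a equals the high half of x * 2^(16-a); a shift amount of 0 has no in-range multiplier, which matches the blend of the original value for that lane in the checks below.

// mulhi_shift_identity.cpp -- sanity check of the identity only; the actual
// lowering in LowerShift is more involved.
#include <cassert>
#include <cstdint>

int main() {
  // x >> a == high 16 bits of the 32-bit product x * 2^(16-a), which is what
  // pmulhuw computes per 16-bit lane.
  for (uint32_t X = 0; X <= 0xFFFF; ++X)
    for (unsigned A = 1; A < 16; ++A) {
      uint32_t Multiplier = 1u << (16 - A);
      uint16_t MulHi = (uint16_t)((X * Multiplier) >> 16);
      assert(MulHi == (uint16_t)(X >> A));
    }
  return 0;
}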

llvm/test/CodeGen/X86/vector-shift-ashr-128.ll

Lines changed: 93 additions & 0 deletions
@@ -1586,6 +1586,99 @@ define <16 x i8> @constant_shift_v16i8(<16 x i8> %a) nounwind {
   ret <16 x i8> %shift
 }
 
+define <16 x i8> @constant_shift_v16i8_pairs(<16 x i8> %a) nounwind {
+; SSE2-LABEL: constant_shift_v16i8_pairs:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movdqa {{.*#+}} xmm1 = [65535,65535,65535,65535,65535,0,65535,65535]
+; SSE2-NEXT:    pandn %xmm0, %xmm1
+; SSE2-NEXT:    pmulhuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE2-NEXT:    por %xmm1, %xmm0
+; SSE2-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE2-NEXT:    movdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; SSE2-NEXT:    pxor %xmm1, %xmm0
+; SSE2-NEXT:    psubb %xmm1, %xmm0
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: constant_shift_v16i8_pairs:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    movdqa {{.*#+}} xmm1 = [32768,4096,512,8192,16384,u,2048,1024]
+; SSE41-NEXT:    pmulhuw %xmm0, %xmm1
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm1[0,1,2,3,4],xmm0[5],xmm1[6,7]
+; SSE41-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE41-NEXT:    movdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; SSE41-NEXT:    pxor %xmm1, %xmm0
+; SSE41-NEXT:    psubb %xmm1, %xmm0
+; SSE41-NEXT:    retq
+;
+; AVX-LABEL: constant_shift_v16i8_pairs:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpmulhuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1 # [32768,4096,512,8192,16384,u,2048,1024]
+; AVX-NEXT:    vpblendw {{.*#+}} xmm0 = xmm1[0,1,2,3,4],xmm0[5],xmm1[6,7]
+; AVX-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX-NEXT:    vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; AVX-NEXT:    vpxor %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vpsubb %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    retq
+;
+; XOP-LABEL: constant_shift_v16i8_pairs:
+; XOP:       # %bb.0:
+; XOP-NEXT:    vpshab {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; XOP-NEXT:    retq
+;
+; AVX512DQ-LABEL: constant_shift_v16i8_pairs:
+; AVX512DQ:       # %bb.0:
+; AVX512DQ-NEXT:    vpmulhuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1 # [32768,4096,512,8192,16384,u,2048,1024]
+; AVX512DQ-NEXT:    vpblendw {{.*#+}} xmm0 = xmm1[0,1,2,3,4],xmm0[5],xmm1[6,7]
+; AVX512DQ-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX512DQ-NEXT:    vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; AVX512DQ-NEXT:    vpxor %xmm1, %xmm0, %xmm0
+; AVX512DQ-NEXT:    vpsubb %xmm1, %xmm0, %xmm0
+; AVX512DQ-NEXT:    retq
+;
+; AVX512BW-LABEL: constant_shift_v16i8_pairs:
+; AVX512BW:       # %bb.0:
+; AVX512BW-NEXT:    # kill: def $xmm0 killed $xmm0 def $zmm0
+; AVX512BW-NEXT:    vpmovsxbw {{.*#+}} xmm1 = [1,4,7,3,2,0,5,6]
+; AVX512BW-NEXT:    vpsrlvw %zmm1, %zmm0, %zmm0
+; AVX512BW-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX512BW-NEXT:    vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; AVX512BW-NEXT:    vpxor %xmm1, %xmm0, %xmm0
+; AVX512BW-NEXT:    vpsubb %xmm1, %xmm0, %xmm0
+; AVX512BW-NEXT:    vzeroupper
+; AVX512BW-NEXT:    retq
+;
+; AVX512DQVL-LABEL: constant_shift_v16i8_pairs:
+; AVX512DQVL:       # %bb.0:
+; AVX512DQVL-NEXT:    vpmulhuw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1 # [32768,4096,512,8192,16384,u,2048,1024]
+; AVX512DQVL-NEXT:    vpblendw {{.*#+}} xmm0 = xmm1[0,1,2,3,4],xmm0[5],xmm1[6,7]
+; AVX512DQVL-NEXT:    vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; AVX512DQVL-NEXT:    vpternlogq $108, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX512DQVL-NEXT:    vpsubb %xmm1, %xmm0, %xmm0
+; AVX512DQVL-NEXT:    retq
+;
+; AVX512BWVL-LABEL: constant_shift_v16i8_pairs:
+; AVX512BWVL:       # %bb.0:
+; AVX512BWVL-NEXT:    vpsrlvw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX512BWVL-NEXT:    vmovdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; AVX512BWVL-NEXT:    vpternlogq $108, {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX512BWVL-NEXT:    vpsubb %xmm1, %xmm0, %xmm0
+; AVX512BWVL-NEXT:    retq
+;
+; X86-SSE-LABEL: constant_shift_v16i8_pairs:
+; X86-SSE:       # %bb.0:
+; X86-SSE-NEXT:    movdqa {{.*#+}} xmm1 = [65535,65535,65535,65535,65535,0,65535,65535]
+; X86-SSE-NEXT:    pandn %xmm0, %xmm1
+; X86-SSE-NEXT:    pmulhuw {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0
+; X86-SSE-NEXT:    por %xmm1, %xmm0
+; X86-SSE-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0
+; X86-SSE-NEXT:    movdqa {{.*#+}} xmm1 = [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2]
+; X86-SSE-NEXT:    pxor %xmm1, %xmm0
+; X86-SSE-NEXT:    psubb %xmm1, %xmm0
+; X86-SSE-NEXT:    retl
+  %shift = ashr <16 x i8> %a, <i8 1, i8 1, i8 4, i8 4, i8 7, i8 7, i8 3, i8 3, i8 2, i8 2, i8 0, i8 0, i8 5, i8 5, i8 6, i8 6>
+  ret <16 x i8> %shift
+}
+
 ;
 ; Uniform Constant Shifts
 ;
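
A quick cross-check of the constants in the new test above (plain arithmetic, not part of the test itself): the vpmulhuw multipliers [32768,4096,512,8192,16384,u,2048,1024] are 2^(16-Amt) for the per-pair shift amounts <1,4,7,3,2,0,5,6>, with the amount-0 lane left undef and blended from the source, while the XOR/SUB constant [64,64,8,8,1,1,16,16,32,32,128,128,4,4,2,2] is the per-byte sign-bit mask 0x80 >> Amt from the sign-extension step. The snippet below verifies the latter.

// ashr_pairs_constants.cpp -- cross-checks the folded constants above.
#include <cassert>
#include <cstdint>

int main() {
  const unsigned Amt[16] = {1, 1, 4, 4, 7, 7, 3, 3, 2, 2, 0, 0, 5, 5, 6, 6};
  const uint8_t SignBitMask[16] = {64, 64, 8,   8,   1, 1, 16, 16,
                                   32, 32, 128, 128, 4, 4, 2,  2};
  for (int I = 0; I < 16; ++I)
    assert((uint8_t)(0x80u >> Amt[I]) == SignBitMask[I]);
  return 0;
}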
