Skip to content

Commit b320d37

Browse files
committed
[X86] Add handling for select(icmp_uge(amt,BW),0,shift_logical(x,amt)) -> avx2 shift(x,amt)
We need to catch this otherwise pre-AVX512 targets will fold this to and(icmp_ult(amt,BW),shift_logical(x,amt))
1 parent deff3af commit b320d37

File tree

3 files changed

+24
-30
lines changed

3 files changed

+24
-30
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46190,20 +46190,32 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
4619046190
// to bitwidth-1 for unsigned shifts, effectively performing a maximum left
4619146191
// shift of bitwidth-1 positions. and returns zero for unsigned right shifts
4619246192
// exceeding bitwidth-1.
46193-
if (N->getOpcode() == ISD::VSELECT &&
46194-
(LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SHL) &&
46195-
supportedVectorVarShift(VT, Subtarget, LHS.getOpcode())) {
46193+
if (N->getOpcode() == ISD::VSELECT) {
4619646194
using namespace llvm::SDPatternMatch;
4619746195
// fold select(icmp_ult(amt,BW),shl(x,amt),0) -> avx2 psllv(x,amt)
4619846196
// fold select(icmp_ult(amt,BW),srl(x,amt),0) -> avx2 psrlv(x,amt)
46199-
if (ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
46197+
if ((LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SHL) &&
46198+
supportedVectorVarShift(VT, Subtarget, LHS.getOpcode()) &&
46199+
ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
4620046200
sd_match(Cond, m_SetCC(m_Specific(LHS.getOperand(1)),
4620146201
m_SpecificInt(VT.getScalarSizeInBits()),
4620246202
m_SpecificCondCode(ISD::SETULT)))) {
4620346203
return DAG.getNode(LHS.getOpcode() == ISD::SRL ? X86ISD::VSRLV
4620446204
: X86ISD::VSHLV,
4620546205
DL, VT, LHS.getOperand(0), LHS.getOperand(1));
4620646206
}
46207+
// fold select(icmp_uge(amt,BW),0,shl(x,amt)) -> avx2 psllv(x,amt)
46208+
// fold select(icmp_uge(amt,BW),0,srl(x,amt)) -> avx2 psrlv(x,amt)
46209+
if ((RHS.getOpcode() == ISD::SRL || RHS.getOpcode() == ISD::SHL) &&
46210+
supportedVectorVarShift(VT, Subtarget, RHS.getOpcode()) &&
46211+
ISD::isConstantSplatVectorAllZeros(LHS.getNode()) &&
46212+
sd_match(Cond, m_SetCC(m_Specific(RHS.getOperand(1)),
46213+
m_SpecificInt(VT.getScalarSizeInBits()),
46214+
m_SpecificCondCode(ISD::SETUGE)))) {
46215+
return DAG.getNode(RHS.getOpcode() == ISD::SRL ? X86ISD::VSRLV
46216+
: X86ISD::VSHLV,
46217+
DL, VT, RHS.getOperand(0), RHS.getOperand(1));
46218+
}
4620746219
}
4620846220

4620946221
// Early exit check

llvm/test/CodeGen/X86/combine-shl.ll

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,19 +1086,10 @@ define <4 x i32> @combine_vec_shl_commuted_clamped1(<4 x i32> %sh, <4 x i32> %am
10861086
; SSE41-NEXT: pand %xmm2, %xmm0
10871087
; SSE41-NEXT: retq
10881088
;
1089-
; AVX2-LABEL: combine_vec_shl_commuted_clamped1:
1090-
; AVX2: # %bb.0:
1091-
; AVX2-NEXT: vpbroadcastd {{.*#+}} xmm2 = [31,31,31,31]
1092-
; AVX2-NEXT: vpsllvd %xmm1, %xmm0, %xmm0
1093-
; AVX2-NEXT: vpminud %xmm2, %xmm1, %xmm2
1094-
; AVX2-NEXT: vpcmpeqd %xmm2, %xmm1, %xmm1
1095-
; AVX2-NEXT: vpand %xmm0, %xmm1, %xmm0
1096-
; AVX2-NEXT: retq
1097-
;
1098-
; AVX512-LABEL: combine_vec_shl_commuted_clamped1:
1099-
; AVX512: # %bb.0:
1100-
; AVX512-NEXT: vpsllvd %xmm1, %xmm0, %xmm0
1101-
; AVX512-NEXT: retq
1089+
; AVX-LABEL: combine_vec_shl_commuted_clamped1:
1090+
; AVX: # %bb.0:
1091+
; AVX-NEXT: vpsllvd %xmm1, %xmm0, %xmm0
1092+
; AVX-NEXT: retq
11021093
%cmp.i = icmp uge <4 x i32> %amt, <i32 32, i32 32, i32 32, i32 32>
11031094
%shl = shl <4 x i32> %sh, %amt
11041095
%1 = select <4 x i1> %cmp.i, <4 x i32> zeroinitializer, <4 x i32> %shl

llvm/test/CodeGen/X86/combine-srl.ll

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -828,19 +828,10 @@ define <4 x i32> @combine_vec_lshr_commuted_clamped1(<4 x i32> %sh, <4 x i32> %a
828828
; SSE41-NEXT: pand %xmm2, %xmm0
829829
; SSE41-NEXT: retq
830830
;
831-
; AVX2-LABEL: combine_vec_lshr_commuted_clamped1:
832-
; AVX2: # %bb.0:
833-
; AVX2-NEXT: vpbroadcastd {{.*#+}} xmm2 = [31,31,31,31]
834-
; AVX2-NEXT: vpsrlvd %xmm1, %xmm0, %xmm0
835-
; AVX2-NEXT: vpminud %xmm2, %xmm1, %xmm2
836-
; AVX2-NEXT: vpcmpeqd %xmm2, %xmm1, %xmm1
837-
; AVX2-NEXT: vpand %xmm0, %xmm1, %xmm0
838-
; AVX2-NEXT: retq
839-
;
840-
; AVX512-LABEL: combine_vec_lshr_commuted_clamped1:
841-
; AVX512: # %bb.0:
842-
; AVX512-NEXT: vpsrlvd %xmm1, %xmm0, %xmm0
843-
; AVX512-NEXT: retq
831+
; AVX-LABEL: combine_vec_lshr_commuted_clamped1:
832+
; AVX: # %bb.0:
833+
; AVX-NEXT: vpsrlvd %xmm1, %xmm0, %xmm0
834+
; AVX-NEXT: retq
844835
%cmp.i = icmp uge <4 x i32> %amt, <i32 32, i32 32, i32 32, i32 32>
845836
%shr = lshr <4 x i32> %sh, %amt
846837
%1 = select <4 x i1> %cmp.i, <4 x i32> zeroinitializer, <4 x i32> %shr

0 commit comments

Comments
 (0)