
Commit 703c065

LU-JOHN authored and tomtor committed
[AMDGPU] Convert more 64-bit lshr to 32-bit if shift amt>=32 (llvm#138204)
Convert vector 64-bit lshr to 32-bit if the shift amount is known to be >= 32. Also convert scalar 64-bit lshr to 32-bit if the shift amount is variable but known to be >= 32.

---------

Signed-off-by: John Lu <[email protected]>
1 parent 8973ba9 commit 703c065
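The change rests on a simple bit-level identity: once the shift amount is at least 32, only the high 32 bits of the source can reach the result, so a 64-bit lshr can be done as one 32-bit shift of the high half plus a zero high word. The following standalone C++ sketch (illustrative only, not part of the commit) checks that identity on the host:

// Host-side check of the identity behind the combine: for shift amounts in
// [32, 64), a 64-bit logical shift right equals a 32-bit shift of the high
// half by (amount - 32), zero-extended back to 64 bits.
#include <cassert>
#include <cstdint>

static uint64_t Lshr64Via32(uint64_t X, unsigned C) {
  assert(C >= 32 && C < 64 && "identity only holds for amounts in [32, 64)");
  uint32_t Hi = static_cast<uint32_t>(X >> 32); // hi_32(x)
  uint32_t Lo = Hi >> (C - 32);                 // single 32-bit shift
  return static_cast<uint64_t>(Lo);             // build_pair (Lo, 0)
}

int main() {
  const uint64_t X = 0x123456789abcdef0ULL;
  for (unsigned C = 32; C < 64; ++C)
    assert(Lshr64Via32(X, C) == (X >> C));
  return 0;
}

On subtargets where the 64-bit shift is a quarter-rate instruction, replacing it with a single full-rate 32-bit shift and a zero move is the win the combine is after, per the comments in the diff below.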

3 files changed, 196 insertions(+), 161 deletions(-)


llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 91 additions & 39 deletions
@@ -4097,7 +4097,7 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
   if (VT.getScalarType() != MVT::i64)
     return SDValue();
 
-  // i64 (shl x, C) -> (build_pair 0, (shl x, C -32))
+  // i64 (shl x, C) -> (build_pair 0, (shl x, C - 32))
 
   // On some subtargets, 64-bit shift is a quarter rate instruction. In the
   // common case, splitting this into a move and a 32-bit shift is faster and
@@ -4117,12 +4117,12 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
     ShiftAmt = DAG.getConstant(RHSVal - TargetScalarType.getSizeInBits(), SL,
                                TargetType);
   } else {
-    SDValue truncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
+    SDValue TruncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
     const SDValue ShiftMask =
         DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
     // This AND instruction will clamp out of bounds shift values.
     // It will also be removed during later instruction selection.
-    ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
+    ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, TruncShiftAmt, ShiftMask);
   }
 
   SDValue Lo = DAG.getNode(ISD::TRUNCATE, SL, TargetType, LHS);
@@ -4181,50 +4181,105 @@ SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
 
 SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
                                                 DAGCombinerInfo &DCI) const {
-  auto *RHS = dyn_cast<ConstantSDNode>(N->getOperand(1));
-  if (!RHS)
-    return SDValue();
-
+  SDValue RHS = N->getOperand(1);
+  ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
   EVT VT = N->getValueType(0);
   SDValue LHS = N->getOperand(0);
-  unsigned ShiftAmt = RHS->getZExtValue();
   SelectionDAG &DAG = DCI.DAG;
   SDLoc SL(N);
+  unsigned RHSVal;
+
+  if (CRHS) {
+    RHSVal = CRHS->getZExtValue();
 
-  // fold (srl (and x, c1 << c2), c2) -> (and (srl(x, c2), c1)
-  // this improves the ability to match BFE patterns in isel.
-  if (LHS.getOpcode() == ISD::AND) {
-    if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand(1))) {
-      unsigned MaskIdx, MaskLen;
-      if (Mask->getAPIntValue().isShiftedMask(MaskIdx, MaskLen) &&
-          MaskIdx == ShiftAmt) {
-        return DAG.getNode(
-            ISD::AND, SL, VT,
-            DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(0), N->getOperand(1)),
-            DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(1), N->getOperand(1)));
+    // fold (srl (and x, c1 << c2), c2) -> (and (srl(x, c2), c1)
+    // this improves the ability to match BFE patterns in isel.
+    if (LHS.getOpcode() == ISD::AND) {
+      if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand(1))) {
+        unsigned MaskIdx, MaskLen;
+        if (Mask->getAPIntValue().isShiftedMask(MaskIdx, MaskLen) &&
+            MaskIdx == RHSVal) {
+          return DAG.getNode(ISD::AND, SL, VT,
+                             DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(0),
+                                         N->getOperand(1)),
+                             DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(1),
+                                         N->getOperand(1)));
+        }
       }
     }
   }
 
-  if (VT != MVT::i64)
+  if (VT.getScalarType() != MVT::i64)
     return SDValue();
 
-  if (ShiftAmt < 32)
+  // for C >= 32
+  // i64 (srl x, C) -> (build_pair (srl hi_32(x), C -32), 0)
+
+  // On some subtargets, 64-bit shift is a quarter rate instruction. In the
+  // common case, splitting this into a move and a 32-bit shift is faster and
+  // the same code size.
+  KnownBits Known = DAG.computeKnownBits(RHS);
+
+  EVT ElementType = VT.getScalarType();
+  EVT TargetScalarType = ElementType.getHalfSizedIntegerVT(*DAG.getContext());
+  EVT TargetType = VT.isVector() ? VT.changeVectorElementType(TargetScalarType)
+                                 : TargetScalarType;
+
+  if (Known.getMinValue().getZExtValue() < TargetScalarType.getSizeInBits())
     return SDValue();
 
-  // srl i64:x, C for C >= 32
-  // =>
-  //   build_pair (srl hi_32(x), C - 32), 0
-  SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
+  SDValue ShiftAmt;
+  if (CRHS) {
+    ShiftAmt = DAG.getConstant(RHSVal - TargetScalarType.getSizeInBits(), SL,
+                               TargetType);
+  } else {
+    SDValue TruncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
+    const SDValue ShiftMask =
+        DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
+    // This AND instruction will clamp out of bounds shift values.
+    // It will also be removed during later instruction selection.
+    ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, TruncShiftAmt, ShiftMask);
+  }
+
+  const SDValue Zero = DAG.getConstant(0, SL, TargetScalarType);
+  EVT ConcatType;
+  SDValue Hi;
+  SDLoc LHSSL(LHS);
+  // Bitcast LHS into ConcatType so hi-half of source can be extracted into Hi
+  if (VT.isVector()) {
+    unsigned NElts = TargetType.getVectorNumElements();
+    ConcatType = TargetType.getDoubleNumVectorElementsVT(*DAG.getContext());
+    SDValue SplitLHS = DAG.getNode(ISD::BITCAST, LHSSL, ConcatType, LHS);
+    SmallVector<SDValue, 8> HiOps(NElts);
+    SmallVector<SDValue, 16> HiAndLoOps;
 
-  SDValue Hi = getHiHalf64(LHS, DAG);
+    DAG.ExtractVectorElements(SplitLHS, HiAndLoOps, /*Start=*/0, NElts * 2);
+    for (unsigned I = 0; I != NElts; ++I)
+      HiOps[I] = HiAndLoOps[2 * I + 1];
+    Hi = DAG.getNode(ISD::BUILD_VECTOR, LHSSL, TargetType, HiOps);
+  } else {
+    const SDValue One = DAG.getConstant(1, LHSSL, TargetScalarType);
+    ConcatType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
+    SDValue SplitLHS = DAG.getNode(ISD::BITCAST, LHSSL, ConcatType, LHS);
+    Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, LHSSL, TargetType, SplitLHS, One);
+  }
 
-  SDValue NewConst = DAG.getConstant(ShiftAmt - 32, SL, MVT::i32);
-  SDValue NewShift = DAG.getNode(ISD::SRL, SL, MVT::i32, Hi, NewConst);
+  SDValue NewShift = DAG.getNode(ISD::SRL, SL, TargetType, Hi, ShiftAmt);
 
-  SDValue BuildPair = DAG.getBuildVector(MVT::v2i32, SL, {NewShift, Zero});
+  SDValue Vec;
+  if (VT.isVector()) {
+    unsigned NElts = TargetType.getVectorNumElements();
+    SmallVector<SDValue, 8> LoOps;
+    SmallVector<SDValue, 16> HiAndLoOps(NElts * 2, Zero);
 
-  return DAG.getNode(ISD::BITCAST, SL, MVT::i64, BuildPair);
+    DAG.ExtractVectorElements(NewShift, LoOps, 0, NElts);
+    for (unsigned I = 0; I != NElts; ++I)
+      HiAndLoOps[2 * I] = LoOps[I];
+    Vec = DAG.getNode(ISD::BUILD_VECTOR, SL, ConcatType, HiAndLoOps);
+  } else {
+    Vec = DAG.getBuildVector(ConcatType, SL, {NewShift, Zero});
+  }
+  return DAG.getNode(ISD::BITCAST, SL, VT, Vec);
 }
 
 SDValue AMDGPUTargetLowering::performTruncateCombine(
@@ -5209,21 +5264,18 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
 
     break;
   }
-  case ISD::SHL: {
+  case ISD::SHL:
+  case ISD::SRL: {
     // Range metadata can be invalidated when loads are converted to legal types
     // (e.g. v2i64 -> v4i32).
-    // Try to convert vector shl before type legalization so that range metadata
-    // can be utilized.
+    // Try to convert vector shl/srl before type legalization so that range
+    // metadata can be utilized.
     if (!(N->getValueType(0).isVector() &&
           DCI.getDAGCombineLevel() == BeforeLegalizeTypes) &&
         DCI.getDAGCombineLevel() < AfterLegalizeDAG)
       break;
-    return performShlCombine(N, DCI);
-  }
-  case ISD::SRL: {
-    if (DCI.getDAGCombineLevel() < AfterLegalizeDAG)
-      break;
-
+    if (N->getOpcode() == ISD::SHL)
+      return performShlCombine(N, DCI);
    return performSrlCombine(N, DCI);
  }
  case ISD::SRA: {
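For the vector case, the new performSrlCombine path above does the equivalent lane surgery through a bitcast: the i64 elements are viewed as pairs of i32 lanes, the odd (high) lanes are extracted, shifted by the reduced amount, and written back into the even lanes with zeros in the odd lanes. Below is a minimal host-side model of that reshuffle, assuming little-endian lane order and a fixed <2 x i64> input with a uniform shift amount; it is illustrative only, not the DAG code.

// Sketch of the vector lowering: bitcast v2i64 -> v4i32, shift the high
// (odd) lanes by (C - 32), place the results in the even lanes, zero the
// odd lanes, then bitcast back. Assumes a little-endian host like the GPU.
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>

static std::array<uint64_t, 2> SrlVecVia32(std::array<uint64_t, 2> LHS,
                                           unsigned C) {
  assert(C >= 32 && C < 64 && "combine only fires for amounts known >= 32");
  uint32_t HiAndLoOps[4];                // bitcast v2i64 -> v4i32
  std::memcpy(HiAndLoOps, LHS.data(), sizeof(HiAndLoOps));
  for (unsigned I = 0; I != 2; ++I) {
    uint32_t Hi = HiAndLoOps[2 * I + 1]; // odd lane holds the hi half
    HiAndLoOps[2 * I] = Hi >> (C - 32);  // 32-bit srl into the even lane
    HiAndLoOps[2 * I + 1] = 0;           // hi half of the result is zero
  }
  std::array<uint64_t, 2> Result;        // bitcast v4i32 -> v2i64
  std::memcpy(Result.data(), HiAndLoOps, sizeof(HiAndLoOps));
  return Result;
}

int main() {
  std::array<uint64_t, 2> In = {0x123456789abcdef0ULL, 0xfedcba9876543210ULL};
  std::array<uint64_t, 2> Out = SrlVecVia32(In, 32);
  assert(Out[0] == (In[0] >> 32) && Out[1] == (In[1] >> 32));
  return 0;
}

The test update below shows the payoff: the v2i64 lshr-by-32 feeding v_mad_u64_u32 no longer needs 64-bit shift sequences, so the high dwords are consumed directly.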

llvm/test/CodeGen/AMDGPU/mad_64_32.ll

Lines changed: 30 additions & 47 deletions
@@ -1945,16 +1945,14 @@ define <2 x i64> @lshr_mad_i64_vec(<2 x i64> %arg0) #0 {
 ; CI-LABEL: lshr_mad_i64_vec:
 ; CI: ; %bb.0:
 ; CI-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; CI-NEXT: v_mov_b32_e32 v6, v3
-; CI-NEXT: v_mov_b32_e32 v3, v1
-; CI-NEXT: v_mov_b32_e32 v1, 0
 ; CI-NEXT: s_mov_b32 s4, 0xffff1c18
-; CI-NEXT: v_mad_u64_u32 v[4:5], s[4:5], v3, s4, v[0:1]
-; CI-NEXT: v_mov_b32_e32 v3, v1
+; CI-NEXT: v_mad_u64_u32 v[4:5], s[4:5], v1, s4, v[0:1]
 ; CI-NEXT: s_mov_b32 s4, 0xffff1118
-; CI-NEXT: v_mad_u64_u32 v[2:3], s[4:5], v6, s4, v[2:3]
+; CI-NEXT: v_mad_u64_u32 v[6:7], s[4:5], v3, s4, v[2:3]
+; CI-NEXT: v_sub_i32_e32 v1, vcc, v5, v1
+; CI-NEXT: v_sub_i32_e32 v3, vcc, v7, v3
 ; CI-NEXT: v_mov_b32_e32 v0, v4
-; CI-NEXT: v_mov_b32_e32 v1, v5
+; CI-NEXT: v_mov_b32_e32 v2, v6
 ; CI-NEXT: s_setpc_b64 s[30:31]
 ;
 ; SI-LABEL: lshr_mad_i64_vec:
@@ -1977,44 +1975,28 @@ define <2 x i64> @lshr_mad_i64_vec(<2 x i64> %arg0) #0 {
 ; GFX9-LABEL: lshr_mad_i64_vec:
 ; GFX9: ; %bb.0:
 ; GFX9-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX9-NEXT: v_mov_b32_e32 v6, v3
-; GFX9-NEXT: v_mov_b32_e32 v3, v1
-; GFX9-NEXT: v_mov_b32_e32 v1, 0
 ; GFX9-NEXT: s_mov_b32 s4, 0xffff1c18
-; GFX9-NEXT: v_mad_u64_u32 v[4:5], s[4:5], v3, s4, v[0:1]
-; GFX9-NEXT: v_mov_b32_e32 v3, v1
+; GFX9-NEXT: v_mad_u64_u32 v[4:5], s[4:5], v1, s4, v[0:1]
 ; GFX9-NEXT: s_mov_b32 s4, 0xffff1118
-; GFX9-NEXT: v_mad_u64_u32 v[2:3], s[4:5], v6, s4, v[2:3]
+; GFX9-NEXT: v_mad_u64_u32 v[6:7], s[4:5], v3, s4, v[2:3]
+; GFX9-NEXT: v_sub_u32_e32 v1, v5, v1
+; GFX9-NEXT: v_sub_u32_e32 v3, v7, v3
 ; GFX9-NEXT: v_mov_b32_e32 v0, v4
-; GFX9-NEXT: v_mov_b32_e32 v1, v5
+; GFX9-NEXT: v_mov_b32_e32 v2, v6
 ; GFX9-NEXT: s_setpc_b64 s[30:31]
 ;
-; GFX1100-LABEL: lshr_mad_i64_vec:
-; GFX1100: ; %bb.0:
-; GFX1100-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX1100-NEXT: v_mov_b32_e32 v8, v3
-; GFX1100-NEXT: v_dual_mov_b32 v6, v1 :: v_dual_mov_b32 v1, 0
-; GFX1100-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX1100-NEXT: v_mad_u64_u32 v[4:5], null, 0xffff1c18, v6, v[0:1]
-; GFX1100-NEXT: v_dual_mov_b32 v3, v1 :: v_dual_mov_b32 v0, v4
-; GFX1100-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
-; GFX1100-NEXT: v_mad_u64_u32 v[6:7], null, 0xffff1118, v8, v[2:3]
-; GFX1100-NEXT: v_dual_mov_b32 v1, v5 :: v_dual_mov_b32 v2, v6
-; GFX1100-NEXT: s_delay_alu instid0(VALU_DEP_2)
-; GFX1100-NEXT: v_mov_b32_e32 v3, v7
-; GFX1100-NEXT: s_setpc_b64 s[30:31]
-;
-; GFX1150-LABEL: lshr_mad_i64_vec:
-; GFX1150: ; %bb.0:
-; GFX1150-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GFX1150-NEXT: v_dual_mov_b32 v4, v3 :: v_dual_mov_b32 v5, v1
-; GFX1150-NEXT: v_mov_b32_e32 v1, 0
-; GFX1150-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_3)
-; GFX1150-NEXT: v_mov_b32_e32 v3, v1
-; GFX1150-NEXT: v_mad_u64_u32 v[0:1], null, 0xffff1c18, v5, v[0:1]
-; GFX1150-NEXT: s_delay_alu instid0(VALU_DEP_2)
-; GFX1150-NEXT: v_mad_u64_u32 v[2:3], null, 0xffff1118, v4, v[2:3]
-; GFX1150-NEXT: s_setpc_b64 s[30:31]
+; GFX11-LABEL: lshr_mad_i64_vec:
+; GFX11: ; %bb.0:
+; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; GFX11-NEXT: v_mad_u64_u32 v[4:5], null, 0xffff1c18, v1, v[0:1]
+; GFX11-NEXT: v_mad_u64_u32 v[6:7], null, 0xffff1118, v3, v[2:3]
+; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_3)
+; GFX11-NEXT: v_sub_nc_u32_e32 v1, v5, v1
+; GFX11-NEXT: v_mov_b32_e32 v0, v4
+; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_3) | instskip(NEXT) | instid1(VALU_DEP_4)
+; GFX11-NEXT: v_sub_nc_u32_e32 v3, v7, v3
+; GFX11-NEXT: v_mov_b32_e32 v2, v6
+; GFX11-NEXT: s_setpc_b64 s[30:31]
 ;
 ; GFX12-LABEL: lshr_mad_i64_vec:
 ; GFX12: ; %bb.0:
@@ -2023,13 +2005,14 @@ define <2 x i64> @lshr_mad_i64_vec(<2 x i64> %arg0) #0 {
 ; GFX12-NEXT: s_wait_samplecnt 0x0
 ; GFX12-NEXT: s_wait_bvhcnt 0x0
 ; GFX12-NEXT: s_wait_kmcnt 0x0
-; GFX12-NEXT: v_dual_mov_b32 v4, v3 :: v_dual_mov_b32 v5, v1
-; GFX12-NEXT: v_mov_b32_e32 v1, 0
-; GFX12-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_3)
-; GFX12-NEXT: v_mov_b32_e32 v3, v1
-; GFX12-NEXT: v_mad_co_u64_u32 v[0:1], null, 0xffff1c18, v5, v[0:1]
-; GFX12-NEXT: s_delay_alu instid0(VALU_DEP_2)
-; GFX12-NEXT: v_mad_co_u64_u32 v[2:3], null, 0xffff1118, v4, v[2:3]
+; GFX12-NEXT: v_mad_co_u64_u32 v[4:5], null, 0xffff1c18, v1, v[0:1]
+; GFX12-NEXT: v_mad_co_u64_u32 v[6:7], null, 0xffff1118, v3, v[2:3]
+; GFX12-NEXT: s_delay_alu instid0(VALU_DEP_2) | instskip(NEXT) | instid1(VALU_DEP_3)
+; GFX12-NEXT: v_sub_nc_u32_e32 v1, v5, v1
+; GFX12-NEXT: v_mov_b32_e32 v0, v4
+; GFX12-NEXT: s_delay_alu instid0(VALU_DEP_3) | instskip(NEXT) | instid1(VALU_DEP_4)
+; GFX12-NEXT: v_sub_nc_u32_e32 v3, v7, v3
+; GFX12-NEXT: v_mov_b32_e32 v2, v6
 ; GFX12-NEXT: s_setpc_b64 s[30:31]
   %lsh = lshr <2 x i64> %arg0, <i64 32, i64 32>
   %mul = mul <2 x i64> %lsh, <i64 s0xffffffffffff1c18, i64 s0xffffffffffff1118>
