Skip to content

Commit f2991bf

Browse files
authored
[AMDGPU] Convert 64-bit sra to 32-bit if shift amt >= 32 (llvm#144421)
Use KnownBits to convert 64-bit sra to 32-bit sra. Scaled-down alive2 verification with 16/8-bit types: https://alive2.llvm.org/ce/z/LamASk --------- Signed-off-by: John Lu <[email protected]>
1 parent 442c417 commit f2991bf

File tree

2 files changed

+200
-93
lines changed

2 files changed

+200
-93
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 87 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4151,32 +4151,96 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
41514151

41524152
SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
41534153
DAGCombinerInfo &DCI) const {
4154-
if (N->getValueType(0) != MVT::i64)
4154+
SDValue RHS = N->getOperand(1);
4155+
ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
4156+
EVT VT = N->getValueType(0);
4157+
SDValue LHS = N->getOperand(0);
4158+
SelectionDAG &DAG = DCI.DAG;
4159+
SDLoc SL(N);
4160+
4161+
if (VT.getScalarType() != MVT::i64)
41554162
return SDValue();
41564163

4157-
const ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(N->getOperand(1));
4158-
if (!RHS)
4164+
// For C >= 32
4165+
// i64 (sra x, C) -> (build_pair (sra hi_32(x), C - 32), sra hi_32(x), 31))
4166+
4167+
// On some subtargets, 64-bit shift is a quarter rate instruction. In the
4168+
// common case, splitting this into a move and a 32-bit shift is faster and
4169+
// the same code size.
4170+
KnownBits Known = DAG.computeKnownBits(RHS);
4171+
4172+
EVT ElementType = VT.getScalarType();
4173+
EVT TargetScalarType = ElementType.getHalfSizedIntegerVT(*DAG.getContext());
4174+
EVT TargetType = VT.isVector() ? VT.changeVectorElementType(TargetScalarType)
4175+
: TargetScalarType;
4176+
4177+
if (Known.getMinValue().getZExtValue() < TargetScalarType.getSizeInBits())
41594178
return SDValue();
41604179

4161-
SelectionDAG &DAG = DCI.DAG;
4162-
SDLoc SL(N);
4163-
unsigned RHSVal = RHS->getZExtValue();
4180+
SDValue ShiftFullAmt =
4181+
DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
4182+
SDValue ShiftAmt;
4183+
if (CRHS) {
4184+
unsigned RHSVal = CRHS->getZExtValue();
4185+
ShiftAmt = DAG.getConstant(RHSVal - TargetScalarType.getSizeInBits(), SL,
4186+
TargetType);
4187+
} else if (Known.getMinValue().getZExtValue() ==
4188+
(ElementType.getSizeInBits() - 1)) {
4189+
ShiftAmt = ShiftFullAmt;
4190+
} else {
4191+
SDValue truncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
4192+
const SDValue ShiftMask =
4193+
DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
4194+
// This AND instruction will clamp out of bounds shift values.
4195+
// It will also be removed during later instruction selection.
4196+
ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
4197+
}
41644198

4165-
// For C >= 32
4166-
// (sra i64:x, C) -> build_pair (sra hi_32(x), C - 32), (sra hi_32(x), 31)
4167-
if (RHSVal >= 32) {
4168-
SDValue Hi = getHiHalf64(N->getOperand(0), DAG);
4169-
Hi = DAG.getFreeze(Hi);
4170-
SDValue HiShift = DAG.getNode(ISD::SRA, SL, MVT::i32, Hi,
4171-
DAG.getConstant(31, SL, MVT::i32));
4172-
SDValue LoShift = DAG.getNode(ISD::SRA, SL, MVT::i32, Hi,
4173-
DAG.getConstant(RHSVal - 32, SL, MVT::i32));
4199+
EVT ConcatType;
4200+
SDValue Hi;
4201+
SDLoc LHSSL(LHS);
4202+
// Bitcast LHS into ConcatType so hi-half of source can be extracted into Hi
4203+
if (VT.isVector()) {
4204+
unsigned NElts = TargetType.getVectorNumElements();
4205+
ConcatType = TargetType.getDoubleNumVectorElementsVT(*DAG.getContext());
4206+
SDValue SplitLHS = DAG.getNode(ISD::BITCAST, LHSSL, ConcatType, LHS);
4207+
SmallVector<SDValue, 8> HiOps(NElts);
4208+
SmallVector<SDValue, 16> HiAndLoOps;
41744209

4175-
SDValue BuildVec = DAG.getBuildVector(MVT::v2i32, SL, {LoShift, HiShift});
4176-
return DAG.getNode(ISD::BITCAST, SL, MVT::i64, BuildVec);
4210+
DAG.ExtractVectorElements(SplitLHS, HiAndLoOps, 0, NElts * 2);
4211+
for (unsigned I = 0; I != NElts; ++I) {
4212+
HiOps[I] = HiAndLoOps[2 * I + 1];
4213+
}
4214+
Hi = DAG.getNode(ISD::BUILD_VECTOR, LHSSL, TargetType, HiOps);
4215+
} else {
4216+
const SDValue One = DAG.getConstant(1, LHSSL, TargetScalarType);
4217+
ConcatType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
4218+
SDValue SplitLHS = DAG.getNode(ISD::BITCAST, LHSSL, ConcatType, LHS);
4219+
Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, LHSSL, TargetType, SplitLHS, One);
41774220
}
4221+
Hi = DAG.getFreeze(Hi);
41784222

4179-
return SDValue();
4223+
SDValue HiShift = DAG.getNode(ISD::SRA, SL, TargetType, Hi, ShiftFullAmt);
4224+
SDValue NewShift = DAG.getNode(ISD::SRA, SL, TargetType, Hi, ShiftAmt);
4225+
4226+
SDValue Vec;
4227+
if (VT.isVector()) {
4228+
unsigned NElts = TargetType.getVectorNumElements();
4229+
SmallVector<SDValue, 8> HiOps;
4230+
SmallVector<SDValue, 8> LoOps;
4231+
SmallVector<SDValue, 16> HiAndLoOps(NElts * 2);
4232+
4233+
DAG.ExtractVectorElements(HiShift, HiOps, 0, NElts);
4234+
DAG.ExtractVectorElements(NewShift, LoOps, 0, NElts);
4235+
for (unsigned I = 0; I != NElts; ++I) {
4236+
HiAndLoOps[2 * I + 1] = HiOps[I];
4237+
HiAndLoOps[2 * I] = LoOps[I];
4238+
}
4239+
Vec = DAG.getNode(ISD::BUILD_VECTOR, SL, ConcatType, HiAndLoOps);
4240+
} else {
4241+
Vec = DAG.getBuildVector(ConcatType, SL, {NewShift, HiShift});
4242+
}
4243+
return DAG.getNode(ISD::BITCAST, SL, VT, Vec);
41804244
}
41814245

41824246
SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
@@ -4213,7 +4277,7 @@ SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
42134277
return SDValue();
42144278

42154279
// for C >= 32
4216-
// i64 (srl x, C) -> (build_pair (srl hi_32(x), C -32), 0)
4280+
// i64 (srl x, C) -> (build_pair (srl hi_32(x), C - 32), 0)
42174281

42184282
// On some subtargets, 64-bit shift is a quarter rate instruction. In the
42194283
// common case, splitting this into a move and a 32-bit shift is faster and
@@ -5265,25 +5329,22 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
52655329
break;
52665330
}
52675331
case ISD::SHL:
5332+
case ISD::SRA:
52685333
case ISD::SRL: {
52695334
// Range metadata can be invalidated when loads are converted to legal types
52705335
// (e.g. v2i64 -> v4i32).
5271-
// Try to convert vector shl/srl before type legalization so that range
5336+
// Try to convert vector shl/sra/srl before type legalization so that range
52725337
// metadata can be utilized.
52735338
if (!(N->getValueType(0).isVector() &&
52745339
DCI.getDAGCombineLevel() == BeforeLegalizeTypes) &&
52755340
DCI.getDAGCombineLevel() < AfterLegalizeDAG)
52765341
break;
52775342
if (N->getOpcode() == ISD::SHL)
52785343
return performShlCombine(N, DCI);
5344+
if (N->getOpcode() == ISD::SRA)
5345+
return performSraCombine(N, DCI);
52795346
return performSrlCombine(N, DCI);
52805347
}
5281-
case ISD::SRA: {
5282-
if (DCI.getDAGCombineLevel() < AfterLegalizeDAG)
5283-
break;
5284-
5285-
return performSraCombine(N, DCI);
5286-
}
52875348
case ISD::TRUNCATE:
52885349
return performTruncateCombine(N, DCI);
52895350
case ISD::MUL:

0 commit comments

Comments
 (0)