Skip to content

AMDGPU: Reduce shl64 to shl32 if shift range is [63-32] #125574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 13, 2025
Merged
101 changes: 59 additions & 42 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4040,47 +4040,48 @@ SDValue AMDGPUTargetLowering::splitBinaryBitConstantOpImpl(
SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
EVT VT = N->getValueType(0);

ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(N->getOperand(1));
if (!RHS)
return SDValue();

SDValue LHS = N->getOperand(0);
unsigned RHSVal = RHS->getZExtValue();
if (!RHSVal)
return LHS;

SDValue RHS = N->getOperand(1);
ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
SDLoc SL(N);
SelectionDAG &DAG = DCI.DAG;

switch (LHS->getOpcode()) {
default:
break;
case ISD::ZERO_EXTEND:
case ISD::SIGN_EXTEND:
case ISD::ANY_EXTEND: {
SDValue X = LHS->getOperand(0);

if (VT == MVT::i32 && RHSVal == 16 && X.getValueType() == MVT::i16 &&
isOperationLegal(ISD::BUILD_VECTOR, MVT::v2i16)) {
// Prefer build_vector as the canonical form if packed types are legal.
// (shl ([asz]ext i16:x), 16 -> build_vector 0, x
SDValue Vec = DAG.getBuildVector(MVT::v2i16, SL,
{ DAG.getConstant(0, SL, MVT::i16), LHS->getOperand(0) });
return DAG.getNode(ISD::BITCAST, SL, MVT::i32, Vec);
}
unsigned RHSVal;
if (CRHS) {
RHSVal = CRHS->getZExtValue();
if (!RHSVal)
return LHS;

// shl (ext x) => zext (shl x), if shift does not overflow int
if (VT != MVT::i64)
break;
KnownBits Known = DAG.computeKnownBits(X);
unsigned LZ = Known.countMinLeadingZeros();
if (LZ < RHSVal)
switch (LHS->getOpcode()) {
default:
break;
EVT XVT = X.getValueType();
SDValue Shl = DAG.getNode(ISD::SHL, SL, XVT, X, SDValue(RHS, 0));
return DAG.getZExtOrTrunc(Shl, SL, VT);
}
case ISD::ZERO_EXTEND:
case ISD::SIGN_EXTEND:
case ISD::ANY_EXTEND: {
SDValue X = LHS->getOperand(0);

if (VT == MVT::i32 && RHSVal == 16 && X.getValueType() == MVT::i16 &&
isOperationLegal(ISD::BUILD_VECTOR, MVT::v2i16)) {
// Prefer build_vector as the canonical form if packed types are legal.
// (shl ([asz]ext i16:x), 16 -> build_vector 0, x
SDValue Vec = DAG.getBuildVector(
MVT::v2i16, SL,
{DAG.getConstant(0, SL, MVT::i16), LHS->getOperand(0)});
return DAG.getNode(ISD::BITCAST, SL, MVT::i32, Vec);
}

// shl (ext x) => zext (shl x), if shift does not overflow int
if (VT != MVT::i64)
break;
KnownBits Known = DAG.computeKnownBits(X);
unsigned LZ = Known.countMinLeadingZeros();
if (LZ < RHSVal)
break;
EVT XVT = X.getValueType();
SDValue Shl = DAG.getNode(ISD::SHL, SL, XVT, X, SDValue(CRHS, 0));
return DAG.getZExtOrTrunc(Shl, SL, VT);
}
}
}

if (VT != MVT::i64)
Expand All @@ -4091,18 +4092,34 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
// On some subtargets, 64-bit shift is a quarter rate instruction. In the
// common case, splitting this into a move and a 32-bit shift is faster and
// the same code size.
if (RHSVal < 32)
EVT TargetType = VT.getHalfSizedIntegerVT(*DAG.getContext());
EVT TargetVecPairType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
KnownBits Known = DAG.computeKnownBits(RHS);

if (Known.getMinValue().getZExtValue() < TargetType.getSizeInBits())
return SDValue();
SDValue ShiftAmt;

SDValue ShiftAmt = DAG.getConstant(RHSVal - 32, SL, MVT::i32);
if (CRHS) {
ShiftAmt =
DAG.getConstant(RHSVal - TargetType.getSizeInBits(), SL, TargetType);
} else {
SDValue truncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
const SDValue ShiftMask =
DAG.getConstant(TargetType.getSizeInBits() - 1, SL, TargetType);
// This AND instruction will clamp out of bounds shift values.
// It will also be removed during later instruction selection.
ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
}

SDValue Lo = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, LHS);
SDValue NewShift = DAG.getNode(ISD::SHL, SL, MVT::i32, Lo, ShiftAmt);
SDValue Lo = DAG.getNode(ISD::TRUNCATE, SL, TargetType, LHS);
SDValue NewShift =
DAG.getNode(ISD::SHL, SL, TargetType, Lo, ShiftAmt, N->getFlags());

const SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
const SDValue Zero = DAG.getConstant(0, SL, TargetType);

SDValue Vec = DAG.getBuildVector(MVT::v2i32, SL, {Zero, NewShift});
return DAG.getNode(ISD::BITCAST, SL, MVT::i64, Vec);
SDValue Vec = DAG.getBuildVector(TargetVecPairType, SL, {Zero, NewShift});
return DAG.getNode(ISD::BITCAST, SL, VT, Vec);
}

SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
Expand Down
Loading