Skip to content

Commit f7aad60

Browse files
authored
[RISCV] Fold vector shift of sext/zext to widening multiply (#121563)
(shl (sext X), C) -> (vwmulsu X, 1u << C)
(shl (zext X), C) -> (vwmulu X, 1u << C)
1 parent d5488f1 commit f7aad60

12 files changed

+1902
-1800
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17759,6 +17759,83 @@ static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
1775917759
return DAG.getZExtOrTrunc(Pop, DL, VT);
1776017760
}
1776117761

17762+
/// Combine a left shift node (ISD::SHL or RISCVISD::SHL_VL).
///
/// The generic combineOp_VLToVWOp_VL combine is tried first. If it produces
/// nothing and the DAG has already been legalized, a shift by a constant
/// splat C of a single-use sign/zero extension from exactly half the element
/// width is rewritten as a widening multiply by 1 << C:
///   (shl (sext x), C) -> (vwmulsu x, 1u << C)
///   (shl (zext x), C) -> (vwmulu  x, 1u << C)
///
/// \param N         the shift node being combined.
/// \param DCI       combiner state; supplies the SelectionDAG and the phase.
/// \param Subtarget forwarded to the helpers that need subtarget information.
/// \returns the replacement value, or an empty SDValue when no fold applies.
static SDValue performSHLCombine(SDNode *N,
                                 TargetLowering::DAGCombinerInfo &DCI,
                                 const RISCVSubtarget &Subtarget) {
  // (shl (zext x), y) -> (vwsll x, y)
  if (SDValue Widened = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
    return Widened;

  // The multiply-based folds below are only attempted after DAG legalization.
  if (!DCI.isAfterLegalizeDAG())
    return SDValue();

  // The extension must feed this shift only.
  SDValue Ext = N->getOperand(0);
  if (!Ext.hasOneUse())
    return SDValue();

  // Classify the extension; its signedness selects the widening multiply.
  const unsigned ExtOpc = Ext.getOpcode();
  const bool IsSExt =
      ExtOpc == ISD::SIGN_EXTEND || ExtOpc == RISCVISD::VSEXT_VL;
  const bool IsZExt =
      ExtOpc == ISD::ZERO_EXTEND || ExtOpc == RISCVISD::VZEXT_VL;
  if (!IsSExt && !IsZExt)
    return SDValue();
  const unsigned MulOpc = IsSExt ? RISCVISD::VWMULSU_VL : RISCVISD::VWMULU_VL;

  // The shift amount has to be a constant splat: either a generic constant
  // splat vector or a VMV_V_X_VL whose operand 1 is a constant.
  SDValue Amt = N->getOperand(1);
  APInt SplatVal;
  uint64_t Shift;
  if (ISD::isConstantSplatVector(Amt.getNode(), SplatVal)) {
    Shift = SplatVal.getZExtValue();
  } else {
    const bool IsConstantVMV =
        Amt.getOpcode() == RISCVISD::VMV_V_X_VL &&
        Amt.getOperand(1).getOpcode() == ISD::Constant;
    if (!IsConstantVMV)
      return SDValue();
    Shift = Amt.getConstantOperandVal(1);
  }

  // Better foldings:
  // (shl (sext x), 1) -> (vwadd x, x)
  // (shl (zext x), 1) -> (vwaddu x, x)
  if (Shift <= 1)
    return SDValue();

  // 1 << Shift must be representable in the narrow element type.
  SDValue Src = Ext.getOperand(0);
  MVT SrcVT = Src.getSimpleValueType();
  uint64_t SrcBits = SrcVT.getScalarSizeInBits();
  if (Shift >= SrcBits)
    return SDValue();

  // Only an exact doubling of the element width is handled.
  MVT VT = N->getSimpleValueType(0);
  if (VT.getScalarSizeInBits() != SrcBits * 2)
    return SDValue();

  SelectionDAG &DAG = DCI.DAG;
  SDLoc DL(N);

  // ISD::SHL carries no passthru/mask/VL operands, so use an undef passthru
  // with the default mask and VL; SHL_VL provides them as operands 2..4.
  SDValue Passthru, Mask, VL;
  const unsigned ShiftOpc = N->getOpcode();
  if (ShiftOpc == ISD::SHL) {
    Passthru = DAG.getUNDEF(VT);
    std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
  } else if (ShiftOpc == RISCVISD::SHL_VL) {
    Passthru = N->getOperand(2);
    Mask = N->getOperand(3);
    VL = N->getOperand(4);
  } else {
    llvm_unreachable("Expected SHL");
  }

  SDValue Multiplier = DAG.getConstant(1ULL << Shift, SDLoc(Amt), SrcVT);
  return DAG.getNode(MulOpc, DL, VT, Src, Multiplier, Passthru, Mask, VL);
}
17838+
1776217839
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1776317840
DAGCombinerInfo &DCI) const {
1776417841
SelectionDAG &DAG = DCI.DAG;
@@ -18392,7 +18469,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1839218469
break;
1839318470
}
1839418471
case RISCVISD::SHL_VL:
18395-
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
18472+
if (SDValue V = performSHLCombine(N, DCI, Subtarget))
1839618473
return V;
1839718474
[[fallthrough]];
1839818475
case RISCVISD::SRA_VL:
@@ -18417,7 +18494,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1841718494
case ISD::SRL:
1841818495
case ISD::SHL: {
1841918496
if (N->getOpcode() == ISD::SHL) {
18420-
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
18497+
if (SDValue V = performSHLCombine(N, DCI, Subtarget))
1842118498
return V;
1842218499
}
1842318500
SDValue ShAmt = N->getOperand(1);

0 commit comments

Comments
 (0)