Skip to content

Commit 3e55ac9

Browse files
authored
[RISCV] Strength reduce mul by 2^N - 2^M (#88983)
This is a three instruction expansion, and does not depend on zba, so most of the test changes are in base RV32/64I configurations. With zba, this gets immediates such as 14, 28, 30, 56, 60, 62.. which aren't covered by our other expansions.
1 parent 67f5312 commit 3e55ac9

File tree

14 files changed

+422
-374
lines changed

14 files changed

+422
-374
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 107 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -13706,8 +13706,8 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
1370613706
if (VT != Subtarget.getXLenVT())
1370713707
return SDValue();
1370813708

13709-
if (!Subtarget.hasStdExtZba() && !Subtarget.hasVendorXTHeadBa())
13710-
return SDValue();
13709+
const bool HasShlAdd =
13710+
Subtarget.hasStdExtZba() || Subtarget.hasVendorXTHeadBa();
1371113711

1371213712
ConstantSDNode *CNode = dyn_cast<ConstantSDNode>(N->getOperand(1));
1371313713
if (!CNode)
@@ -13720,107 +13720,123 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
1372013720
// other target properly freezes X in these cases either.
1372113721
SDValue X = N->getOperand(0);
1372213722

13723-
for (uint64_t Divisor : {3, 5, 9}) {
13724-
if (MulAmt % Divisor != 0)
13725-
continue;
13726-
uint64_t MulAmt2 = MulAmt / Divisor;
13727-
// 3/5/9 * 2^N -> shl (shXadd X, X), N
13728-
if (isPowerOf2_64(MulAmt2)) {
13729-
SDLoc DL(N);
13730-
SDValue X = N->getOperand(0);
13731-
// Put the shift first if we can fold a zext into the
13732-
// shift forming a slli.uw.
13733-
if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
13734-
X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
13735-
SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X,
13736-
DAG.getConstant(Log2_64(MulAmt2), DL, VT));
13737-
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
13738-
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), Shl);
13723+
if (HasShlAdd) {
13724+
for (uint64_t Divisor : {3, 5, 9}) {
13725+
if (MulAmt % Divisor != 0)
13726+
continue;
13727+
uint64_t MulAmt2 = MulAmt / Divisor;
13728+
// 3/5/9 * 2^N -> shl (shXadd X, X), N
13729+
if (isPowerOf2_64(MulAmt2)) {
13730+
SDLoc DL(N);
13731+
SDValue X = N->getOperand(0);
13732+
// Put the shift first if we can fold a zext into the
13733+
// shift forming a slli.uw.
13734+
if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
13735+
X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
13736+
SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X,
13737+
DAG.getConstant(Log2_64(MulAmt2), DL, VT));
13738+
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
13739+
DAG.getConstant(Log2_64(Divisor - 1), DL, VT),
13740+
Shl);
13741+
}
13742+
// Otherwise, put rhe shl second so that it can fold with following
13743+
// instructions (e.g. sext or add).
13744+
SDValue Mul359 =
13745+
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13746+
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13747+
return DAG.getNode(ISD::SHL, DL, VT, Mul359,
13748+
DAG.getConstant(Log2_64(MulAmt2), DL, VT));
13749+
}
13750+
13751+
// 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
13752+
if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
13753+
SDLoc DL(N);
13754+
SDValue Mul359 =
13755+
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13756+
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13757+
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
13758+
DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT),
13759+
Mul359);
1373913760
}
13740-
// Otherwise, put rhe shl second so that it can fold with following
13741-
// instructions (e.g. sext or add).
13742-
SDValue Mul359 =
13743-
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13744-
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13745-
return DAG.getNode(ISD::SHL, DL, VT, Mul359,
13746-
DAG.getConstant(Log2_64(MulAmt2), DL, VT));
1374713761
}
1374813762

13749-
// 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
13750-
if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
13751-
SDLoc DL(N);
13752-
SDValue Mul359 =
13753-
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13754-
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13755-
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
13756-
DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT),
13757-
Mul359);
13758-
}
13759-
}
13760-
13761-
// If this is a power 2 + 2/4/8, we can use a shift followed by a single
13762-
// shXadd. First check if this a sum of two power of 2s because that's
13763-
// easy. Then count how many zeros are up to the first bit.
13764-
if (isPowerOf2_64(MulAmt & (MulAmt - 1))) {
13765-
unsigned ScaleShift = llvm::countr_zero(MulAmt);
13766-
if (ScaleShift >= 1 && ScaleShift < 4) {
13767-
unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1)));
13768-
SDLoc DL(N);
13769-
SDValue Shift1 =
13770-
DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
13771-
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13772-
DAG.getConstant(ScaleShift, DL, VT), Shift1);
13763+
// If this is a power 2 + 2/4/8, we can use a shift followed by a single
13764+
// shXadd. First check if this a sum of two power of 2s because that's
13765+
// easy. Then count how many zeros are up to the first bit.
13766+
if (isPowerOf2_64(MulAmt & (MulAmt - 1))) {
13767+
unsigned ScaleShift = llvm::countr_zero(MulAmt);
13768+
if (ScaleShift >= 1 && ScaleShift < 4) {
13769+
unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1)));
13770+
SDLoc DL(N);
13771+
SDValue Shift1 =
13772+
DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
13773+
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13774+
DAG.getConstant(ScaleShift, DL, VT), Shift1);
13775+
}
1377313776
}
13774-
}
1377513777

13776-
// 2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x)
13777-
// This is the two instruction form, there are also three instruction
13778-
// variants we could implement. e.g.
13779-
// (2^(1,2,3) * 3,5,9 + 1) << C2
13780-
// 2^(C1>3) * 3,5,9 +/- 1
13781-
for (uint64_t Divisor : {3, 5, 9}) {
13782-
uint64_t C = MulAmt - 1;
13783-
if (C <= Divisor)
13784-
continue;
13785-
unsigned TZ = llvm::countr_zero(C);
13786-
if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
13787-
SDLoc DL(N);
13788-
SDValue Mul359 =
13789-
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13790-
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13791-
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
13792-
DAG.getConstant(TZ, DL, VT), X);
13778+
// 2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x)
13779+
// This is the two instruction form, there are also three instruction
13780+
// variants we could implement. e.g.
13781+
// (2^(1,2,3) * 3,5,9 + 1) << C2
13782+
// 2^(C1>3) * 3,5,9 +/- 1
13783+
for (uint64_t Divisor : {3, 5, 9}) {
13784+
uint64_t C = MulAmt - 1;
13785+
if (C <= Divisor)
13786+
continue;
13787+
unsigned TZ = llvm::countr_zero(C);
13788+
if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
13789+
SDLoc DL(N);
13790+
SDValue Mul359 =
13791+
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13792+
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13793+
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
13794+
DAG.getConstant(TZ, DL, VT), X);
13795+
}
1379313796
}
13794-
}
1379513797

13796-
// 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X))
13797-
if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
13798-
unsigned ScaleShift = llvm::countr_zero(MulAmt - 1);
13799-
if (ScaleShift >= 1 && ScaleShift < 4) {
13800-
unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2)));
13801-
SDLoc DL(N);
13802-
SDValue Shift1 =
13803-
DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
13804-
return DAG.getNode(ISD::ADD, DL, VT, Shift1,
13805-
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13806-
DAG.getConstant(ScaleShift, DL, VT), X));
13798+
// 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X))
13799+
if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
13800+
unsigned ScaleShift = llvm::countr_zero(MulAmt - 1);
13801+
if (ScaleShift >= 1 && ScaleShift < 4) {
13802+
unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2)));
13803+
SDLoc DL(N);
13804+
SDValue Shift1 =
13805+
DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
13806+
return DAG.getNode(ISD::ADD, DL, VT, Shift1,
13807+
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13808+
DAG.getConstant(ScaleShift, DL, VT), X));
13809+
}
1380713810
}
13808-
}
1380913811

13810-
// 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x))
13811-
for (uint64_t Offset : {3, 5, 9}) {
13812-
if (isPowerOf2_64(MulAmt + Offset)) {
13813-
SDLoc DL(N);
13814-
SDValue Shift1 =
13815-
DAG.getNode(ISD::SHL, DL, VT, X,
13816-
DAG.getConstant(Log2_64(MulAmt + Offset), DL, VT));
13817-
SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13818-
DAG.getConstant(Log2_64(Offset - 1), DL, VT),
13819-
X);
13820-
return DAG.getNode(ISD::SUB, DL, VT, Shift1, Mul359);
13812+
// 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x))
13813+
for (uint64_t Offset : {3, 5, 9}) {
13814+
if (isPowerOf2_64(MulAmt + Offset)) {
13815+
SDLoc DL(N);
13816+
SDValue Shift1 =
13817+
DAG.getNode(ISD::SHL, DL, VT, X,
13818+
DAG.getConstant(Log2_64(MulAmt + Offset), DL, VT));
13819+
SDValue Mul359 =
13820+
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13821+
DAG.getConstant(Log2_64(Offset - 1), DL, VT), X);
13822+
return DAG.getNode(ISD::SUB, DL, VT, Shift1, Mul359);
13823+
}
1382113824
}
1382213825
}
1382313826

13827+
// 2^N - 2^M -> (sub (shl X, C1), (shl X, C2))
13828+
uint64_t MulAmtLowBit = MulAmt & (-MulAmt);
13829+
if (isPowerOf2_64(MulAmt + MulAmtLowBit)) {
13830+
uint64_t ShiftAmt1 = MulAmt + MulAmtLowBit;
13831+
SDLoc DL(N);
13832+
SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
13833+
DAG.getConstant(Log2_64(ShiftAmt1), DL, VT));
13834+
SDValue Shift2 =
13835+
DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
13836+
DAG.getConstant(Log2_64(MulAmtLowBit), DL, VT));
13837+
return DAG.getNode(ISD::SUB, DL, VT, Shift1, Shift2);
13838+
}
13839+
1382413840
return SDValue();
1382513841
}
1382613842

0 commit comments

Comments
 (0)