Skip to content

Commit 1af67fb

Browse files
committed
Use factorization
1 parent faac57c commit 1af67fb

File tree

7 files changed

+400
-452
lines changed

7 files changed

+400
-452
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 82 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "llvm/CodeGen/MachineRegisterInfo.h"
3333
#include "llvm/CodeGen/SDPatternMatch.h"
3434
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
35+
#include "llvm/CodeGen/SelectionDAGNodes.h"
3536
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
3637
#include "llvm/CodeGen/ValueTypes.h"
3738
#include "llvm/IR/DiagnosticInfo.h"
@@ -49,6 +50,7 @@
4950
#include "llvm/Support/KnownBits.h"
5051
#include "llvm/Support/MathExtras.h"
5152
#include "llvm/Support/raw_ostream.h"
53+
#include <cstdint>
5254
#include <optional>
5355

5456
using namespace llvm;
@@ -15437,15 +15439,10 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG,
1543715439
return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
1543815440
}
1543915441

15440-
// Try to expand a multiply to a sequence of shifts and add/subs,
15441-
// for a machine w/o native mul instruction.
15442-
static SDValue expandMulToBasicOps(SDNode *N, SelectionDAG &DAG,
15443-
uint64_t MulAmt) {
15444-
const uint64_t BitWidth = N->getValueType(0).getFixedSizeInBits();
15445-
SDLoc DL(N);
15446-
15447-
if (MulAmt == 0)
15448-
return DAG.getConstant(0, DL, N->getValueType(0));
15442+
static SDValue expandMulToNAFSequence(SDNode *N, SelectionDAG &DAG,
15443+
const SDLoc &DL, uint64_t MulAmt) {
15444+
EVT VT = N->getValueType(0);
15445+
const uint64_t BitWidth = VT.getFixedSizeInBits();
1544915446

1545015447
// Find the Non-adjacent form of the multiplier.
1545115448
llvm::SmallVector<std::pair<bool, uint64_t>> Sequence; // {isAdd, shamt}
@@ -15470,17 +15467,89 @@ static SDValue expandMulToBasicOps(SDNode *N, SelectionDAG &DAG,
1547015467
SDValue ShiftVal;
1547115468
if (Op.second > 0)
1547215469
ShiftVal =
15473-
DAG.getNode(ISD::SHL, DL, N->getValueType(0), N0,
15474-
DAG.getConstant(Op.second, DL, N->getValueType(0)));
15470+
DAG.getNode(ISD::SHL, DL, VT, N0, DAG.getConstant(Op.second, DL, VT));
1547515471
else
1547615472
ShiftVal = N0;
1547715473

1547815474
ISD::NodeType AddSubOp = Op.first ? ISD::ADD : ISD::SUB;
15479-
Result = DAG.getNode(AddSubOp, DL, N->getValueType(0), Result, ShiftVal);
15475+
Result = DAG.getNode(AddSubOp, DL, VT, Result, ShiftVal);
1548015476
}
15481-
1548215477
return Result;
1548315478
}
15479+
// Try to expand a multiply to a sequence of shifts and add/subs,
15480+
// for a machine w/o native mul instruction.
15481+
static SDValue expandMulToBasicOps(SDNode *N, SelectionDAG &DAG,
15482+
uint64_t MulAmt) {
15483+
EVT VT = N->getValueType(0);
15484+
const uint64_t BitWidth = VT.getFixedSizeInBits();
15485+
SDLoc DL(N);
15486+
15487+
if (MulAmt == 0)
15488+
return DAG.getConstant(0, DL, N->getValueType(0));
15489+
15490+
// Try to factorize into (2^N) * (2^M_1 +/- 1) + (2^M_2 +/- 1) + ...
15491+
uint64_t E = MulAmt;
15492+
uint64_t TrailingZeros = 0;
15493+
15494+
while (E > 0 && (E & 1) == 0) {
15495+
E >>= 1;
15496+
TrailingZeros++;
15497+
}
15498+
15499+
llvm::SmallVector<std::pair<bool, uint64_t>> Factors; // {is_2^M+1, M}
15500+
15501+
while (E > 1) {
15502+
bool Found = false;
15503+
for (int64_t I = BitWidth - 1; I >= 2; --I) {
15504+
uint64_t Factor = 1ULL << I;
15505+
15506+
if (E % (Factor - 1) == 0) {
15507+
Factors.push_back({false, I});
15508+
E /= Factor - 1;
15509+
Found = true;
15510+
break;
15511+
}
15512+
15513+
if (E % (Factor + 1) == 0) {
15514+
Factors.push_back({true, I});
15515+
E /= Factor + 1;
15516+
Found = true;
15517+
break;
15518+
}
15519+
}
15520+
if (!Found)
15521+
break;
15522+
}
15523+
15524+
SDValue Result;
15525+
SDValue N0 = N->getOperand(0);
15526+
15527+
bool UseFactorization =
15528+
!Factors.empty() && (E < MulAmt) && (Factors.size() < 5);
15529+
15530+
if (UseFactorization) {
15531+
if (E == 1)
15532+
Result = N0;
15533+
else
15534+
Result = expandMulToNAFSequence(N, DAG, DL, E);
15535+
15536+
for (const auto &F : Factors) {
15537+
SDValue ShiftVal = DAG.getNode(ISD::SHL, DL, VT, Result,
15538+
DAG.getConstant(F.second, DL, VT));
15539+
15540+
ISD::NodeType AddSubOp = F.first ? ISD::ADD : ISD::SUB;
15541+
Result = DAG.getNode(AddSubOp, DL, N->getValueType(0), ShiftVal, Result);
15542+
}
15543+
15544+
if (TrailingZeros > 0)
15545+
Result = DAG.getNode(ISD::SHL, DL, VT, Result,
15546+
DAG.getConstant(TrailingZeros, DL, VT));
15547+
15548+
return Result;
15549+
}
15550+
15551+
return expandMulToNAFSequence(N, DAG, DL, MulAmt);
15552+
}
1548415553

1548515554
// 2^N +/- 2^M -> (add/sub (shl X, C1), (shl X, C2))
1548615555
static SDValue expandMulToAddOrSubOfShl(SDNode *N, SelectionDAG &DAG,

0 commit comments

Comments
 (0)