Skip to content

Commit 39f6d01

Browse files
preameslukel97
andauthored
[RISCV] Eliminate getVLENFactoredAmount and expose muladd [nfc] (#87881)
This restructures the code to make the fact that most of getVLENFactoredAmount is just a generic multiply w/immediate more obvious and prepare for a couple of upcoming enhancements to this code. Note that I plan to switch mulImm to early return, but decided I'd do that as a separate commit to keep this diff readable. --------- Co-authored-by: Luke Lau <[email protected]>
1 parent 896b5e5 commit 39f6d01

File tree

3 files changed

+40
-40
lines changed

3 files changed

+40
-40
lines changed

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2998,48 +2998,37 @@ MachineInstr *RISCVInstrInfo::convertToThreeAddress(MachineInstr &MI,
29982998
#undef CASE_WIDEOP_OPCODE_LMULS
29992999
#undef CASE_WIDEOP_OPCODE_COMMON
30003000

3001-
void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
3002-
MachineBasicBlock &MBB,
3003-
MachineBasicBlock::iterator II,
3004-
const DebugLoc &DL, Register DestReg,
3005-
int64_t Amount,
3006-
MachineInstr::MIFlag Flag) const {
3007-
assert(Amount > 0 && "There is no need to get VLEN scaled value.");
3008-
assert(Amount % 8 == 0 &&
3009-
"Reserve the stack by the multiple of one vector size.");
3010-
3001+
void RISCVInstrInfo::mulImm(MachineFunction &MF, MachineBasicBlock &MBB,
3002+
MachineBasicBlock::iterator II, const DebugLoc &DL,
3003+
Register DestReg, uint32_t Amount,
3004+
MachineInstr::MIFlag Flag) const {
30113005
MachineRegisterInfo &MRI = MF.getRegInfo();
3012-
assert(isInt<32>(Amount / 8) &&
3013-
"Expect the number of vector registers within 32-bits.");
3014-
uint32_t NumOfVReg = Amount / 8;
3015-
3016-
BuildMI(MBB, II, DL, get(RISCV::PseudoReadVLENB), DestReg).setMIFlag(Flag);
3017-
if (llvm::has_single_bit<uint32_t>(NumOfVReg)) {
3018-
uint32_t ShiftAmount = Log2_32(NumOfVReg);
3006+
if (llvm::has_single_bit<uint32_t>(Amount)) {
3007+
uint32_t ShiftAmount = Log2_32(Amount);
30193008
if (ShiftAmount == 0)
30203009
return;
30213010
BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
30223011
.addReg(DestReg, RegState::Kill)
30233012
.addImm(ShiftAmount)
30243013
.setMIFlag(Flag);
30253014
} else if (STI.hasStdExtZba() &&
3026-
((NumOfVReg % 3 == 0 && isPowerOf2_64(NumOfVReg / 3)) ||
3027-
(NumOfVReg % 5 == 0 && isPowerOf2_64(NumOfVReg / 5)) ||
3028-
(NumOfVReg % 9 == 0 && isPowerOf2_64(NumOfVReg / 9)))) {
3015+
((Amount % 3 == 0 && isPowerOf2_64(Amount / 3)) ||
3016+
(Amount % 5 == 0 && isPowerOf2_64(Amount / 5)) ||
3017+
(Amount % 9 == 0 && isPowerOf2_64(Amount / 9)))) {
30293018
// We can use Zba SHXADD+SLLI instructions for multiply in some cases.
30303019
unsigned Opc;
30313020
uint32_t ShiftAmount;
3032-
if (NumOfVReg % 9 == 0) {
3021+
if (Amount % 9 == 0) {
30333022
Opc = RISCV::SH3ADD;
3034-
ShiftAmount = Log2_64(NumOfVReg / 9);
3035-
} else if (NumOfVReg % 5 == 0) {
3023+
ShiftAmount = Log2_64(Amount / 9);
3024+
} else if (Amount % 5 == 0) {
30363025
Opc = RISCV::SH2ADD;
3037-
ShiftAmount = Log2_64(NumOfVReg / 5);
3038-
} else if (NumOfVReg % 3 == 0) {
3026+
ShiftAmount = Log2_64(Amount / 5);
3027+
} else if (Amount % 3 == 0) {
30393028
Opc = RISCV::SH1ADD;
3040-
ShiftAmount = Log2_64(NumOfVReg / 3);
3029+
ShiftAmount = Log2_64(Amount / 3);
30413030
} else {
3042-
llvm_unreachable("Unexpected number of vregs");
3031+
llvm_unreachable("implied by if-clause");
30433032
}
30443033
if (ShiftAmount)
30453034
BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
@@ -3050,9 +3039,9 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
30503039
.addReg(DestReg, RegState::Kill)
30513040
.addReg(DestReg)
30523041
.setMIFlag(Flag);
3053-
} else if (llvm::has_single_bit<uint32_t>(NumOfVReg - 1)) {
3042+
} else if (llvm::has_single_bit<uint32_t>(Amount - 1)) {
30543043
Register ScaledRegister = MRI.createVirtualRegister(&RISCV::GPRRegClass);
3055-
uint32_t ShiftAmount = Log2_32(NumOfVReg - 1);
3044+
uint32_t ShiftAmount = Log2_32(Amount - 1);
30563045
BuildMI(MBB, II, DL, get(RISCV::SLLI), ScaledRegister)
30573046
.addReg(DestReg)
30583047
.addImm(ShiftAmount)
@@ -3061,9 +3050,9 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
30613050
.addReg(ScaledRegister, RegState::Kill)
30623051
.addReg(DestReg, RegState::Kill)
30633052
.setMIFlag(Flag);
3064-
} else if (llvm::has_single_bit<uint32_t>(NumOfVReg + 1)) {
3053+
} else if (llvm::has_single_bit<uint32_t>(Amount + 1)) {
30653054
Register ScaledRegister = MRI.createVirtualRegister(&RISCV::GPRRegClass);
3066-
uint32_t ShiftAmount = Log2_32(NumOfVReg + 1);
3055+
uint32_t ShiftAmount = Log2_32(Amount + 1);
30673056
BuildMI(MBB, II, DL, get(RISCV::SLLI), ScaledRegister)
30683057
.addReg(DestReg)
30693058
.addImm(ShiftAmount)
@@ -3074,22 +3063,22 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
30743063
.setMIFlag(Flag);
30753064
} else if (STI.hasStdExtM() || STI.hasStdExtZmmul()) {
30763065
Register N = MRI.createVirtualRegister(&RISCV::GPRRegClass);
3077-
movImm(MBB, II, DL, N, NumOfVReg, Flag);
3066+
movImm(MBB, II, DL, N, Amount, Flag);
30783067
BuildMI(MBB, II, DL, get(RISCV::MUL), DestReg)
30793068
.addReg(DestReg, RegState::Kill)
30803069
.addReg(N, RegState::Kill)
30813070
.setMIFlag(Flag);
30823071
} else {
30833072
Register Acc;
30843073
uint32_t PrevShiftAmount = 0;
3085-
for (uint32_t ShiftAmount = 0; NumOfVReg >> ShiftAmount; ShiftAmount++) {
3086-
if (NumOfVReg & (1U << ShiftAmount)) {
3074+
for (uint32_t ShiftAmount = 0; Amount >> ShiftAmount; ShiftAmount++) {
3075+
if (Amount & (1U << ShiftAmount)) {
30873076
if (ShiftAmount)
30883077
BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
30893078
.addReg(DestReg, RegState::Kill)
30903079
.addImm(ShiftAmount - PrevShiftAmount)
30913080
.setMIFlag(Flag);
3092-
if (NumOfVReg >> (ShiftAmount + 1)) {
3081+
if (Amount >> (ShiftAmount + 1)) {
30933082
// If we don't have an accmulator yet, create it and copy DestReg.
30943083
if (!Acc) {
30953084
Acc = MRI.createVirtualRegister(&RISCV::GPRRegClass);

llvm/lib/Target/RISCV/RISCVInstrInfo.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,10 +229,12 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
229229
unsigned OpIdx,
230230
const TargetRegisterInfo *TRI) const override;
231231

232-
void getVLENFactoredAmount(
233-
MachineFunction &MF, MachineBasicBlock &MBB,
234-
MachineBasicBlock::iterator II, const DebugLoc &DL, Register DestReg,
235-
int64_t Amount, MachineInstr::MIFlag Flag = MachineInstr::NoFlags) const;
232+
/// Generate code to multiply the value in DestReg by Amt - handles all
233+
/// the common optimizations for this idiom, and supports fallback for
234+
/// subtargets which don't support multiply instructions.
235+
void mulImm(MachineFunction &MF, MachineBasicBlock &MBB,
236+
MachineBasicBlock::iterator II, const DebugLoc &DL,
237+
Register DestReg, uint32_t Amt, MachineInstr::MIFlag Flag) const;
236238

237239
bool useMachineCombiner() const override { return true; }
238240

llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,16 @@ void RISCVRegisterInfo::adjustReg(MachineBasicBlock &MBB,
195195
Register ScratchReg = DestReg;
196196
if (DestReg == SrcReg)
197197
ScratchReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
198-
TII->getVLENFactoredAmount(MF, MBB, II, DL, ScratchReg, ScalableValue, Flag);
198+
199+
assert(ScalableValue > 0 && "There is no need to get VLEN scaled value.");
200+
assert(ScalableValue % 8 == 0 &&
201+
"Reserve the stack by the multiple of one vector size.");
202+
assert(isInt<32>(ScalableValue / 8) &&
203+
"Expect the number of vector registers within 32-bits.");
204+
uint32_t NumOfVReg = ScalableValue / 8;
205+
BuildMI(MBB, II, DL, TII->get(RISCV::PseudoReadVLENB), ScratchReg)
206+
.setMIFlag(Flag);
207+
TII->mulImm(MF, MBB, II, DL, ScratchReg, NumOfVReg, Flag);
199208
BuildMI(MBB, II, DL, TII->get(ScalableAdjOpc), DestReg)
200209
.addReg(SrcReg).addReg(ScratchReg, RegState::Kill)
201210
.setMIFlag(Flag);

0 commit comments

Comments
 (0)