[RISCV] Add copyPhysRegVector to extract common vector code out of copyPhysReg. #70497
Conversation
@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes

Call this method directly from each vector case with the correct arguments. This allows us to treat each type of copy as its own special case and not pass variables to a common merge point. This is similar to how AArch64 is structured.

I think I can reduce the number of operands to this new method, but I'll do that as a follow up to make this patch easier to review.

Stacked on #70492

Full diff: https://github.com/llvm/llvm-project/pull/70497.diff

2 Files Affected:
- llvm/lib/Target/RISCV/RISCVInstrInfo.cpp (modified)
- llvm/lib/Target/RISCV/RISCVInstrInfo.h (modified)
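Before the full diff, here is a minimal standalone sketch of the dispatch shape the description refers to. The types, register numbers, and the simplified `copyPhysReg` signature are stand-ins for illustration only, not the real LLVM classes: the point is that each register-class case now calls `copyPhysRegVector` directly and returns, instead of setting `Opc`/`NF`/`LMul`/`SubRegIdx` and falling through to a shared merge point.

```cpp
#include <cstdio>

// Stand-ins for RISCVII::VLMUL and the whole-register-move opcodes.
enum class VLMul { M1, M2, M4, M8 };
enum Opcode { VMV1R_V, VMV2R_V, VMV4R_V, VMV8R_V };

// Stand-in for the new RISCVInstrInfo::copyPhysRegVector: the single place
// that knows how to expand a vector register (or register tuple) copy.
static void copyPhysRegVector(unsigned DstReg, unsigned SrcReg, bool KillSrc,
                              Opcode Opc, unsigned NF, VLMul LMul,
                              unsigned SubRegIdx) {
  std::printf("expand copy v%u <- v%u opc=%d NF=%u lmul=%d subidx=%u kill=%d\n",
              DstReg, SrcReg, (int)Opc, NF, (int)LMul, SubRegIdx, (int)KillSrc);
}

// New shape of copyPhysReg: each register-class case dispatches directly and
// returns, instead of accumulating Opc/NF/LMul/SubRegIdx for a shared tail.
static void copyPhysReg(unsigned DstReg, unsigned SrcReg, bool KillSrc,
                        bool IsVRM2, bool IsVRN2M1) {
  if (IsVRM2) { // e.g. the VRM2RegClass case: one whole-register M2 move
    copyPhysRegVector(DstReg, SrcReg, KillSrc, VMV2R_V, /*NF=*/1, VLMul::M2,
                      /*SubRegIdx=*/0);
    return;
  }
  if (IsVRN2M1) { // e.g. the VRN2M1RegClass case: a 2-register LMUL=1 tuple
    copyPhysRegVector(DstReg, SrcReg, KillSrc, VMV1R_V, /*NF=*/2, VLMul::M1,
                      /*SubRegIdx=*/0);
    return;
  }
  // Fallback: plain VR -> VR copy (LMUL=1).
  copyPhysRegVector(DstReg, SrcReg, KillSrc, VMV1R_V, /*NF=*/1, VLMul::M1,
                    /*SubRegIdx=*/0);
}

int main() {
  copyPhysReg(/*DstReg=*/8, /*SrcReg=*/16, /*KillSrc=*/true,
              /*IsVRM2=*/true, /*IsVRN2M1=*/false);
  return 0;
}
```

In the old structure all of these cases funneled into a single `if (IsScalableVector)` block at the bottom of copyPhysReg; the diff below removes that merge point.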
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index ad31b2974993c74..9e4e86100a2115b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -294,6 +294,99 @@ static bool isConvertibleToVMV_V_V(const RISCVSubtarget &STI,
return false;
}
+void RISCVInstrInfo::copyPhysRegVector(
+ MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+ const DebugLoc &DL, MCRegister DstReg, MCRegister SrcReg, bool KillSrc,
+ unsigned Opc, unsigned NF, RISCVII::VLMUL LMul, unsigned SubRegIdx) const {
+ const TargetRegisterInfo *TRI = STI.getRegisterInfo();
+
+ bool UseVMV_V_V = false;
+ bool UseVMV_V_I = false;
+ MachineBasicBlock::const_iterator DefMBBI;
+ if (isConvertibleToVMV_V_V(STI, MBB, MBBI, DefMBBI, LMul)) {
+ UseVMV_V_V = true;
+ // We only need to handle LMUL = 1/2/4/8 here because we only define
+ // vector register classes for LMUL = 1/2/4/8.
+ unsigned VIOpc;
+ switch (LMul) {
+ default:
+ llvm_unreachable("Impossible LMUL for vector register copy.");
+ case RISCVII::LMUL_1:
+ Opc = RISCV::PseudoVMV_V_V_M1;
+ VIOpc = RISCV::PseudoVMV_V_I_M1;
+ break;
+ case RISCVII::LMUL_2:
+ Opc = RISCV::PseudoVMV_V_V_M2;
+ VIOpc = RISCV::PseudoVMV_V_I_M2;
+ break;
+ case RISCVII::LMUL_4:
+ Opc = RISCV::PseudoVMV_V_V_M4;
+ VIOpc = RISCV::PseudoVMV_V_I_M4;
+ break;
+ case RISCVII::LMUL_8:
+ Opc = RISCV::PseudoVMV_V_V_M8;
+ VIOpc = RISCV::PseudoVMV_V_I_M8;
+ break;
+ }
+
+ if (DefMBBI->getOpcode() == VIOpc) {
+ UseVMV_V_I = true;
+ Opc = VIOpc;
+ }
+ }
+
+ if (NF == 1) {
+ auto MIB = BuildMI(MBB, MBBI, DL, get(Opc), DstReg);
+ if (UseVMV_V_V)
+ MIB.addReg(DstReg, RegState::Undef);
+ if (UseVMV_V_I)
+ MIB = MIB.add(DefMBBI->getOperand(2));
+ else
+ MIB = MIB.addReg(SrcReg, getKillRegState(KillSrc));
+ if (UseVMV_V_V) {
+ const MCInstrDesc &Desc = DefMBBI->getDesc();
+ MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
+ MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
+ MIB.addImm(0); // tu, mu
+ MIB.addReg(RISCV::VL, RegState::Implicit);
+ MIB.addReg(RISCV::VTYPE, RegState::Implicit);
+ }
+ } else {
+ int I = 0, End = NF, Incr = 1;
+ unsigned SrcEncoding = TRI->getEncodingValue(SrcReg);
+ unsigned DstEncoding = TRI->getEncodingValue(DstReg);
+ unsigned LMulVal;
+ bool Fractional;
+ std::tie(LMulVal, Fractional) = RISCVVType::decodeVLMUL(LMul);
+ assert(!Fractional && "It is impossible be fractional lmul here.");
+ if (forwardCopyWillClobberTuple(DstEncoding, SrcEncoding, NF * LMulVal)) {
+ I = NF - 1;
+ End = -1;
+ Incr = -1;
+ }
+
+ for (; I != End; I += Incr) {
+ auto MIB = BuildMI(MBB, MBBI, DL, get(Opc),
+ TRI->getSubReg(DstReg, SubRegIdx + I));
+ if (UseVMV_V_V)
+ MIB.addReg(TRI->getSubReg(DstReg, SubRegIdx + I), RegState::Undef);
+ if (UseVMV_V_I)
+ MIB = MIB.add(DefMBBI->getOperand(2));
+ else
+ MIB = MIB.addReg(TRI->getSubReg(SrcReg, SubRegIdx + I),
+ getKillRegState(KillSrc));
+ if (UseVMV_V_V) {
+ const MCInstrDesc &Desc = DefMBBI->getDesc();
+ MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
+ MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
+ MIB.addImm(0); // tu, mu
+ MIB.addReg(RISCV::VL, RegState::Implicit);
+ MIB.addReg(RISCV::VTYPE, RegState::Implicit);
+ }
+ }
+ }
+}
+
void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
const DebugLoc &DL, MCRegister DstReg,
@@ -330,13 +423,8 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
return;
}
- // FPR->FPR copies and VR->VR copies.
- unsigned Opc;
- bool IsScalableVector = true;
- unsigned NF = 1;
- RISCVII::VLMUL LMul = RISCVII::LMUL_1;
- unsigned SubRegIdx = RISCV::sub_vrm1_0;
if (RISCV::FPR16RegClass.contains(DstReg, SrcReg)) {
+ unsigned Opc;
if (STI.hasStdExtZfh()) {
Opc = RISCV::FSGNJ_H;
} else {
@@ -350,176 +438,118 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
&RISCV::FPR32RegClass);
Opc = RISCV::FSGNJ_S;
}
- IsScalableVector = false;
- } else if (RISCV::FPR32RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::FSGNJ_S;
- IsScalableVector = false;
- } else if (RISCV::FPR64RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::FSGNJ_D;
- IsScalableVector = false;
- } else if (RISCV::VRRegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV1R_V;
- LMul = RISCVII::LMUL_1;
- } else if (RISCV::VRM2RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV2R_V;
- LMul = RISCVII::LMUL_2;
- } else if (RISCV::VRM4RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV4R_V;
- LMul = RISCVII::LMUL_4;
- } else if (RISCV::VRM8RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV8R_V;
- LMul = RISCVII::LMUL_8;
- } else if (RISCV::VRN2M1RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV1R_V;
- SubRegIdx = RISCV::sub_vrm1_0;
- NF = 2;
- LMul = RISCVII::LMUL_1;
- } else if (RISCV::VRN2M2RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV2R_V;
- SubRegIdx = RISCV::sub_vrm2_0;
- NF = 2;
- LMul = RISCVII::LMUL_2;
- } else if (RISCV::VRN2M4RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV4R_V;
- SubRegIdx = RISCV::sub_vrm4_0;
- NF = 2;
- LMul = RISCVII::LMUL_4;
- } else if (RISCV::VRN3M1RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV1R_V;
- SubRegIdx = RISCV::sub_vrm1_0;
- NF = 3;
- LMul = RISCVII::LMUL_1;
- } else if (RISCV::VRN3M2RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV2R_V;
- SubRegIdx = RISCV::sub_vrm2_0;
- NF = 3;
- LMul = RISCVII::LMUL_2;
- } else if (RISCV::VRN4M1RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV1R_V;
- SubRegIdx = RISCV::sub_vrm1_0;
- NF = 4;
- LMul = RISCVII::LMUL_1;
- } else if (RISCV::VRN4M2RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV2R_V;
- SubRegIdx = RISCV::sub_vrm2_0;
- NF = 4;
- LMul = RISCVII::LMUL_2;
- } else if (RISCV::VRN5M1RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV1R_V;
- SubRegIdx = RISCV::sub_vrm1_0;
- NF = 5;
- LMul = RISCVII::LMUL_1;
- } else if (RISCV::VRN6M1RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV1R_V;
- SubRegIdx = RISCV::sub_vrm1_0;
- NF = 6;
- LMul = RISCVII::LMUL_1;
- } else if (RISCV::VRN7M1RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV1R_V;
- SubRegIdx = RISCV::sub_vrm1_0;
- NF = 7;
- LMul = RISCVII::LMUL_1;
- } else if (RISCV::VRN8M1RegClass.contains(DstReg, SrcReg)) {
- Opc = RISCV::VMV1R_V;
- SubRegIdx = RISCV::sub_vrm1_0;
- NF = 8;
- LMul = RISCVII::LMUL_1;
- } else {
- llvm_unreachable("Impossible reg-to-reg copy");
+ BuildMI(MBB, MBBI, DL, get(Opc), DstReg)
+ .addReg(SrcReg, getKillRegState(KillSrc))
+ .addReg(SrcReg, getKillRegState(KillSrc));
+ return;
}
- if (IsScalableVector) {
- bool UseVMV_V_V = false;
- bool UseVMV_V_I = false;
- MachineBasicBlock::const_iterator DefMBBI;
- if (isConvertibleToVMV_V_V(STI, MBB, MBBI, DefMBBI, LMul)) {
- UseVMV_V_V = true;
- // We only need to handle LMUL = 1/2/4/8 here because we only define
- // vector register classes for LMUL = 1/2/4/8.
- unsigned VIOpc;
- switch (LMul) {
- default:
- llvm_unreachable("Impossible LMUL for vector register copy.");
- case RISCVII::LMUL_1:
- Opc = RISCV::PseudoVMV_V_V_M1;
- VIOpc = RISCV::PseudoVMV_V_I_M1;
- break;
- case RISCVII::LMUL_2:
- Opc = RISCV::PseudoVMV_V_V_M2;
- VIOpc = RISCV::PseudoVMV_V_I_M2;
- break;
- case RISCVII::LMUL_4:
- Opc = RISCV::PseudoVMV_V_V_M4;
- VIOpc = RISCV::PseudoVMV_V_I_M4;
- break;
- case RISCVII::LMUL_8:
- Opc = RISCV::PseudoVMV_V_V_M8;
- VIOpc = RISCV::PseudoVMV_V_I_M8;
- break;
- }
-
- if (DefMBBI->getOpcode() == VIOpc) {
- UseVMV_V_I = true;
- Opc = VIOpc;
- }
- }
-
- if (NF == 1) {
- auto MIB = BuildMI(MBB, MBBI, DL, get(Opc), DstReg);
- if (UseVMV_V_V)
- MIB.addReg(DstReg, RegState::Undef);
- if (UseVMV_V_I)
- MIB = MIB.add(DefMBBI->getOperand(2));
- else
- MIB = MIB.addReg(SrcReg, getKillRegState(KillSrc));
- if (UseVMV_V_V) {
- const MCInstrDesc &Desc = DefMBBI->getDesc();
- MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
- MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
- MIB.addImm(0); // tu, mu
- MIB.addReg(RISCV::VL, RegState::Implicit);
- MIB.addReg(RISCV::VTYPE, RegState::Implicit);
- }
- } else {
- int I = 0, End = NF, Incr = 1;
- unsigned SrcEncoding = TRI->getEncodingValue(SrcReg);
- unsigned DstEncoding = TRI->getEncodingValue(DstReg);
- unsigned LMulVal;
- bool Fractional;
- std::tie(LMulVal, Fractional) = RISCVVType::decodeVLMUL(LMul);
- assert(!Fractional && "It is impossible be fractional lmul here.");
- if (forwardCopyWillClobberTuple(DstEncoding, SrcEncoding, NF * LMulVal)) {
- I = NF - 1;
- End = -1;
- Incr = -1;
- }
+ if (RISCV::FPR32RegClass.contains(DstReg, SrcReg)) {
+ BuildMI(MBB, MBBI, DL, get(RISCV::FSGNJ_S), DstReg)
+ .addReg(SrcReg, getKillRegState(KillSrc))
+ .addReg(SrcReg, getKillRegState(KillSrc));
+ return;
+ }
- for (; I != End; I += Incr) {
- auto MIB = BuildMI(MBB, MBBI, DL, get(Opc),
- TRI->getSubReg(DstReg, SubRegIdx + I));
- if (UseVMV_V_V)
- MIB.addReg(TRI->getSubReg(DstReg, SubRegIdx + I),
- RegState::Undef);
- if (UseVMV_V_I)
- MIB = MIB.add(DefMBBI->getOperand(2));
- else
- MIB = MIB.addReg(TRI->getSubReg(SrcReg, SubRegIdx + I),
- getKillRegState(KillSrc));
- if (UseVMV_V_V) {
- const MCInstrDesc &Desc = DefMBBI->getDesc();
- MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
- MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
- MIB.addImm(0); // tu, mu
- MIB.addReg(RISCV::VL, RegState::Implicit);
- MIB.addReg(RISCV::VTYPE, RegState::Implicit);
- }
- }
- }
- } else {
- BuildMI(MBB, MBBI, DL, get(Opc), DstReg)
+ if (RISCV::FPR64RegClass.contains(DstReg, SrcReg)) {
+ BuildMI(MBB, MBBI, DL, get(RISCV::FSGNJ_D), DstReg)
.addReg(SrcReg, getKillRegState(KillSrc))
.addReg(SrcReg, getKillRegState(KillSrc));
+ return;
+ }
+
+ // VR->VR copies.
+ if (RISCV::VRRegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+ /*NF=*/1, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+ return;
+ }
+
+ if (RISCV::VRM2RegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV2R_V,
+ /*NF=*/1, RISCVII::LMUL_2, RISCV::sub_vrm1_0);
+ return;
+ }
+
+ if (RISCV::VRM4RegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV4R_V,
+ /*NF=*/1, RISCVII::LMUL_4, RISCV::sub_vrm1_0);
+ return;
+ }
+
+ if (RISCV::VRM8RegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV8R_V,
+ /*NF=*/1, RISCVII::LMUL_8, RISCV::sub_vrm1_0);
+ return;
+ }
+
+ if (RISCV::VRN2M1RegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+ /*NF=*/2, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+ return;
+ }
+
+ if (RISCV::VRN2M2RegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV2R_V,
+ /*NF=*/2, RISCVII::LMUL_2, RISCV::sub_vrm2_0);
+ return;
+ }
+
+ if (RISCV::VRN2M4RegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV4R_V,
+ /*NF=*/2, RISCVII::LMUL_4, RISCV::sub_vrm4_0);
+ return;
+ }
+
+ if (RISCV::VRN3M1RegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+ /*NF=*/3, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+ return;
+ }
+
+ if (RISCV::VRN3M2RegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV2R_V,
+ /*NF=*/3, RISCVII::LMUL_2, RISCV::sub_vrm2_0);
+ return;
+ }
+
+ if (RISCV::VRN4M1RegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+ /*NF=*/4, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+ return;
+ }
+
+ if (RISCV::VRN4M2RegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV2R_V,
+ /*NF=*/4, RISCVII::LMUL_2, RISCV::sub_vrm2_0);
+ return;
+ }
+
+ if (RISCV::VRN5M1RegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+ /*NF=*/5, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+ return;
+ }
+
+ if (RISCV::VRN6M1RegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+ /*NF=*/6, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+ return;
}
+
+ if (RISCV::VRN7M1RegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+ /*NF=*/7, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+ return;
+ }
+
+ if (RISCV::VRN8M1RegClass.contains(DstReg, SrcReg)) {
+ copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV1R_V,
+ /*NF=*/8, RISCVII::LMUL_1, RISCV::sub_vrm1_0);
+ return;
+ }
+
+ llvm_unreachable("Impossible reg-to-reg copy");
}
void RISCVInstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 5584e5571c9bc35..4b93d44ed2d5d85 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -14,6 +14,7 @@
#define LLVM_LIB_TARGET_RISCV_RISCVINSTRINFO_H
#include "RISCVRegisterInfo.h"
+#include "MCTargetDesc/RISCVBaseInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/IR/DiagnosticInfo.h"
@@ -63,6 +64,11 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
unsigned isStoreToStackSlot(const MachineInstr &MI, int &FrameIndex,
unsigned &MemBytes) const override;
+ void copyPhysRegVector(MachineBasicBlock &MBB,
+ MachineBasicBlock::iterator MBBI, const DebugLoc &DL,
+ MCRegister DstReg, MCRegister SrcReg, bool KillSrc,
+ unsigned Opc, unsigned NF, RISCVII::VLMUL LMul,
+ unsigned SubRegIdx) const;
void copyPhysReg(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
const DebugLoc &DL, MCRegister DstReg, MCRegister SrcReg,
bool KillSrc) const override;
  if (RISCV::VRM2RegClass.contains(DstReg, SrcReg)) {
    copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RISCV::VMV2R_V,
                      /*NF=*/1, RISCVII::LMUL_2, RISCV::sub_vrm1_0);
Note the SubRegIdx parameter is only used when NF != 1, so its value doesn't matter. The original code used sub_vrm1_0 for NF==1. There is no sub_vrm8_0 so we can't make it match the LMUL for NF=1.
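To illustrate that remark, here is a tiny standalone sketch (simplified stand-ins for TRI->getSubReg and the helper, not the real implementation): SubRegIdx is only consumed inside the NF > 1 loop via getSubReg(Reg, SubRegIdx + I), so whatever value is passed for an NF == 1 copy is never read.

```cpp
#include <cstdio>

// Stand-in for TRI->getSubReg: pretend a tuple's sub-registers are
// consecutively numbered starting at the tuple's first register.
static unsigned getSubReg(unsigned TupleReg, unsigned Idx) {
  return TupleReg + Idx;
}

// Simplified shape of copyPhysRegVector: SubRegIdx only feeds getSubReg()
// inside the NF > 1 loop, so it is dead for NF == 1.
static void copyPhysRegVector(unsigned DstReg, unsigned SrcReg, unsigned NF,
                              unsigned SubRegIdx) {
  if (NF == 1) {
    std::printf("vmv?r.v v%u, v%u\n", DstReg, SrcReg); // SubRegIdx never read
    return;
  }
  for (unsigned I = 0; I != NF; ++I) // tuple copy: one move per segment
    std::printf("vmv?r.v v%u, v%u\n", getSubReg(DstReg, SubRegIdx + I),
                getSubReg(SrcReg, SubRegIdx + I));
}

int main() {
  copyPhysRegVector(8, 16, /*NF=*/1, /*SubRegIdx=*/99); // 99 is irrelevant
  copyPhysRegVector(8, 16, /*NF=*/2, /*SubRegIdx=*/0);  // 0 stands for sub_vrm1_0
  return 0;
}
```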
LGTM

That is much, much easier to read. Thanks.
✅ With the latest revision this PR passed the C/C++ code formatter.