@@ -819,92 +819,87 @@ void RISCVInstrInfo::loadRegFromStackSlot(
819
819
.setMIFlag (Flags);
820
820
}
821
821
}
822
+ std::optional<unsigned > getFoldedOpcode (MachineFunction &MF, MachineInstr &MI,
823
+ ArrayRef<unsigned > Ops,
824
+ const RISCVSubtarget &ST) {
822
825
823
- MachineInstr *RISCVInstrInfo::foldMemoryOperandImpl (
824
- MachineFunction &MF, MachineInstr &MI, ArrayRef<unsigned > Ops,
825
- MachineBasicBlock::iterator InsertPt, int FrameIndex, LiveIntervals *LIS,
826
- VirtRegMap *VRM) const {
827
826
// The below optimizations narrow the load so they are only valid for little
828
827
// endian.
829
828
// TODO: Support big endian by adding an offset into the frame object?
830
829
if (MF.getDataLayout ().isBigEndian ())
831
- return nullptr ;
830
+ return std::nullopt ;
832
831
833
832
// Fold load from stack followed by sext.b/sext.h/sext.w/zext.b/zext.h/zext.w.
834
833
if (Ops.size () != 1 || Ops[0 ] != 1 )
835
- return nullptr ;
834
+ return std::nullopt ;
836
835
837
- unsigned LoadOpc;
838
836
switch (MI.getOpcode ()) {
839
837
default :
840
- if (RISCV::isSEXT_W (MI)) {
841
- LoadOpc = RISCV::LW;
842
- break ;
843
- }
844
- if (RISCV::isZEXT_W (MI)) {
845
- LoadOpc = RISCV::LWU;
846
- break ;
847
- }
848
- if (RISCV::isZEXT_B (MI)) {
849
- LoadOpc = RISCV::LBU;
850
- break ;
851
- }
852
- if (RISCV::getRVVMCOpcode (MI.getOpcode ()) == RISCV::VMV_X_S) {
853
- unsigned Log2SEW =
854
- MI.getOperand (RISCVII::getSEWOpNum (MI.getDesc ())).getImm ();
855
- if (STI.getXLen () < (1U << Log2SEW))
856
- return nullptr ;
857
- switch (Log2SEW) {
858
- case 3 :
859
- LoadOpc = RISCV::LB;
860
- break ;
861
- case 4 :
862
- LoadOpc = RISCV::LH;
863
- break ;
864
- case 5 :
865
- LoadOpc = RISCV::LW;
866
- break ;
867
- case 6 :
868
- LoadOpc = RISCV::LD;
869
- break ;
870
- default :
871
- llvm_unreachable (" Unexpected SEW" );
872
- }
873
- break ;
874
- }
875
- if (RISCV::getRVVMCOpcode (MI.getOpcode ()) == RISCV::VFMV_F_S) {
876
- unsigned Log2SEW =
877
- MI.getOperand (RISCVII::getSEWOpNum (MI.getDesc ())).getImm ();
878
- switch (Log2SEW) {
879
- case 4 :
880
- LoadOpc = RISCV::FLH;
881
- break ;
882
- case 5 :
883
- LoadOpc = RISCV::FLW;
884
- break ;
885
- case 6 :
886
- LoadOpc = RISCV::FLD;
887
- break ;
888
- default :
889
- llvm_unreachable (" Unexpected SEW" );
890
- }
891
- break ;
892
- }
893
- return nullptr ;
894
- case RISCV::SEXT_H:
895
- LoadOpc = RISCV::LH;
838
+ if (RISCV::isSEXT_W (MI))
839
+ return RISCV::LW;
840
+ if (RISCV::isZEXT_W (MI))
841
+ return RISCV::LWU;
842
+ if (RISCV::isZEXT_B (MI))
843
+ return RISCV::LBU;
896
844
break ;
845
+ case RISCV::SEXT_H:
846
+ return RISCV::LH;
897
847
case RISCV::SEXT_B:
898
- LoadOpc = RISCV::LB;
899
- break ;
848
+ return RISCV::LB;
900
849
case RISCV::ZEXT_H_RV32:
901
850
case RISCV::ZEXT_H_RV64:
902
- LoadOpc = RISCV::LHU;
903
- break ;
851
+ return RISCV::LHU;
852
+ }
853
+
854
+ switch (RISCV::getRVVMCOpcode (MI.getOpcode ())) {
855
+ default :
856
+ return std::nullopt;
857
+ case RISCV::VMV_X_S: {
858
+ unsigned Log2SEW =
859
+ MI.getOperand (RISCVII::getSEWOpNum (MI.getDesc ())).getImm ();
860
+ if (ST.getXLen () < (1U << Log2SEW))
861
+ return std::nullopt;
862
+ switch (Log2SEW) {
863
+ case 3 :
864
+ return RISCV::LB;
865
+ case 4 :
866
+ return RISCV::LH;
867
+ case 5 :
868
+ return RISCV::LW;
869
+ case 6 :
870
+ return RISCV::LD;
871
+ default :
872
+ llvm_unreachable (" Unexpected SEW" );
873
+ }
904
874
}
875
+ case RISCV::VFMV_F_S: {
876
+ unsigned Log2SEW =
877
+ MI.getOperand (RISCVII::getSEWOpNum (MI.getDesc ())).getImm ();
878
+ switch (Log2SEW) {
879
+ case 4 :
880
+ return RISCV::FLH;
881
+ case 5 :
882
+ return RISCV::FLW;
883
+ case 6 :
884
+ return RISCV::FLD;
885
+ default :
886
+ llvm_unreachable (" Unexpected SEW" );
887
+ }
888
+ }
889
+ }
890
+ }
905
891
892
+ // This is the version used during inline spilling
893
+ MachineInstr *RISCVInstrInfo::foldMemoryOperandImpl (
894
+ MachineFunction &MF, MachineInstr &MI, ArrayRef<unsigned > Ops,
895
+ MachineBasicBlock::iterator InsertPt, int FrameIndex, LiveIntervals *LIS,
896
+ VirtRegMap *VRM) const {
897
+
898
+ std::optional<unsigned > LoadOpc = getFoldedOpcode (MF, MI, Ops, STI);
899
+ if (!LoadOpc)
900
+ return nullptr ;
906
901
Register DstReg = MI.getOperand (0 ).getReg ();
907
- return BuildMI (*MI.getParent (), InsertPt, MI.getDebugLoc (), get (LoadOpc),
902
+ return BuildMI (*MI.getParent (), InsertPt, MI.getDebugLoc (), get (* LoadOpc),
908
903
DstReg)
909
904
.addFrameIndex (FrameIndex)
910
905
.addImm (0 );
0 commit comments