Skip to content

[NVPTX] Remove Float register classes #140487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,6 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
Ret = (3 << 28);
} else if (RC == &NVPTX::Int64RegsRegClass) {
Ret = (4 << 28);
} else if (RC == &NVPTX::Float32RegsRegClass) {
Ret = (5 << 28);
} else if (RC == &NVPTX::Float64RegsRegClass) {
Ret = (6 << 28);
} else if (RC == &NVPTX::Int128RegsRegClass) {
Ret = (7 << 28);
} else {
Expand Down
11 changes: 4 additions & 7 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,8 +606,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
addRegisterClass(MVT::f32, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::f64, &NVPTX::Int64RegsRegClass);
addRegisterClass(MVT::f16, &NVPTX::Int16RegsRegClass);
addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
Expand Down Expand Up @@ -4992,24 +4992,21 @@ NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
case 'b':
return std::make_pair(0U, &NVPTX::Int1RegsRegClass);
case 'c':
return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
case 'h':
return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
case 'r':
case 'f':
return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
case 'l':
case 'N':
case 'd':
return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
case 'q': {
if (STI.getSmVersion() < 70)
report_fatal_error("Inline asm with 128 bit operands is only "
"supported for sm_70 and higher!");
return std::make_pair(0U, &NVPTX::Int128RegsRegClass);
}
case 'f':
return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
case 'd':
return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
}
}
return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
Expand Down
12 changes: 2 additions & 10 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,11 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
} else if (DestRC == &NVPTX::Int16RegsRegClass) {
Op = NVPTX::MOV16r;
} else if (DestRC == &NVPTX::Int32RegsRegClass) {
Op = (SrcRC == &NVPTX::Int32RegsRegClass ? NVPTX::IMOV32r
: NVPTX::BITCONVERT_32_F2I);
Op = NVPTX::IMOV32r;
} else if (DestRC == &NVPTX::Int64RegsRegClass) {
Op = (SrcRC == &NVPTX::Int64RegsRegClass ? NVPTX::IMOV64r
: NVPTX::BITCONVERT_64_F2I);
Op = NVPTX::IMOV64r;
} else if (DestRC == &NVPTX::Int128RegsRegClass) {
Op = NVPTX::IMOV128r;
} else if (DestRC == &NVPTX::Float32RegsRegClass) {
Op = (SrcRC == &NVPTX::Float32RegsRegClass ? NVPTX::FMOV32r
: NVPTX::BITCONVERT_32_I2F);
} else if (DestRC == &NVPTX::Float64RegsRegClass) {
Op = (SrcRC == &NVPTX::Float64RegsRegClass ? NVPTX::FMOV64r
: NVPTX::BITCONVERT_64_I2F);
} else {
llvm_unreachable("Bad register copy");
}
Expand Down
8 changes: 0 additions & 8 deletions llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ using namespace llvm;

namespace llvm {
StringRef getNVPTXRegClassName(TargetRegisterClass const *RC) {
if (RC == &NVPTX::Float32RegsRegClass)
return ".b32";
if (RC == &NVPTX::Float64RegsRegClass)
return ".b64";
if (RC == &NVPTX::Int128RegsRegClass)
return ".b128";
if (RC == &NVPTX::Int64RegsRegClass)
Expand Down Expand Up @@ -63,10 +59,6 @@ StringRef getNVPTXRegClassName(TargetRegisterClass const *RC) {
}

StringRef getNVPTXRegClassStr(TargetRegisterClass const *RC) {
if (RC == &NVPTX::Float32RegsRegClass)
return "%f";
if (RC == &NVPTX::Float64RegsRegClass)
return "%fd";
if (RC == &NVPTX::Int128RegsRegClass)
return "%rq";
if (RC == &NVPTX::Int64RegsRegClass)
Expand Down
12 changes: 6 additions & 6 deletions llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ foreach i = 0...4 in {
def RQ#i : NVPTXReg<"%rq"#i>; // 128-bit
def H#i : NVPTXReg<"%h"#i>; // 16-bit float
def HH#i : NVPTXReg<"%hh"#i>; // 2x16-bit float
def F#i : NVPTXReg<"%f"#i>; // 32-bit float
def FL#i : NVPTXReg<"%fd"#i>; // 64-bit float

// Arguments
def ia#i : NVPTXReg<"%ia"#i>;
Expand All @@ -59,14 +57,13 @@ foreach i = 0...31 in {
//===----------------------------------------------------------------------===//
def Int1Regs : NVPTXRegClass<[i1], 8, (add (sequence "P%u", 0, 4))>;
def Int16Regs : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4))>;
def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8], 32,
def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8, f32], 32,
(add (sequence "R%u", 0, 4),
VRFrame32, VRFrameLocal32)>;
def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
def Int64Regs : NVPTXRegClass<[i64, f64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
def Int128Regs : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;
def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>;

def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>;
def Int64ArgRegs : NVPTXRegClass<[i64], 64, (add (sequence "la%u", 0, 4))>;
def Float32ArgRegs : NVPTXRegClass<[f32], 32, (add (sequence "fa%u", 0, 4))>;
Expand All @@ -75,3 +72,6 @@ def Float64ArgRegs : NVPTXRegClass<[f64], 64, (add (sequence "da%u", 0, 4))>;
// Read NVPTXRegisterInfo.cpp to see how VRFrame and VRDepot are used.
def SpecialRegs : NVPTXRegClass<[i32], 32, (add VRFrame32, VRFrameLocal32, VRDepot,
(sequence "ENVREG%u", 0, 31))>;

defvar Float32Regs = Int32Regs;
defvar Float64Regs = Int64Regs;
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
---
name: test
registers:
- { id: 0, class: float32regs }
- { id: 1, class: float32regs }
- { id: 0, class: int32regs }
- { id: 1, class: int32regs }
body: |
bb.0.entry:
%0 = LD_f32 0, 4, 1, 2, 32, &test_param_0, 0
Expand Down
36 changes: 18 additions & 18 deletions llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,24 @@
---
name: test
registers:
- { id: 0, class: float32regs }
- { id: 1, class: float64regs }
- { id: 0, class: int32regs }
- { id: 1, class: int64regs }
- { id: 2, class: int32regs }
- { id: 3, class: float64regs }
- { id: 4, class: float32regs }
- { id: 5, class: float32regs }
- { id: 6, class: float32regs }
- { id: 7, class: float32regs }
- { id: 3, class: int64regs }
- { id: 4, class: int32regs }
- { id: 5, class: int32regs }
- { id: 6, class: int32regs }
- { id: 7, class: int32regs }
body: |
bb.0.entry:
%0 = LD_f32 0, 0, 4, 2, 32, &test_param_0, 0
%1 = CVT_f64_f32 %0, 0
%2 = LD_i32 0, 0, 4, 0, 32, &test_param_1, 0
; CHECK: %3:float64regs = FADD_rnf64ri %1, double 3.250000e+00
; CHECK: %3:int64regs = FADD_rnf64ri %1, double 3.250000e+00
%3 = FADD_rnf64ri %1, double 3.250000e+00
%4 = CVT_f32_f64 %3, 5
%5 = CVT_f32_s32 %2, 5
; CHECK: %6:float32regs = FADD_rnf32ri %5, float 6.250000e+00
; CHECK: %6:int32regs = FADD_rnf32ri %5, float 6.250000e+00
%6 = FADD_rnf32ri %5, float 6.250000e+00
%7 = FMUL_rnf32rr %6, %4
StoreRetvalF32 %7, 0
Expand All @@ -56,24 +56,24 @@ body: |
---
name: test2
registers:
- { id: 0, class: float32regs }
- { id: 1, class: float64regs }
- { id: 0, class: int32regs }
- { id: 1, class: int64regs }
- { id: 2, class: int32regs }
- { id: 3, class: float64regs }
- { id: 4, class: float32regs }
- { id: 5, class: float32regs }
- { id: 6, class: float32regs }
- { id: 7, class: float32regs }
- { id: 3, class: int64regs }
- { id: 4, class: int32regs }
- { id: 5, class: int32regs }
- { id: 6, class: int32regs }
- { id: 7, class: int32regs }
body: |
bb.0.entry:
%0 = LD_f32 0, 0, 4, 2, 32, &test2_param_0, 0
%1 = CVT_f64_f32 %0, 0
%2 = LD_i32 0, 0, 4, 0, 32, &test2_param_1, 0
; CHECK: %3:float64regs = FADD_rnf64ri %1, double 0x7FF8000000000000
; CHECK: %3:int64regs = FADD_rnf64ri %1, double 0x7FF8000000000000
%3 = FADD_rnf64ri %1, double 0x7FF8000000000000
%4 = CVT_f32_f64 %3, 5
%5 = CVT_f32_s32 %2, 5
; CHECK: %6:float32regs = FADD_rnf32ri %5, float 0x7FF8000000000000
; CHECK: %6:int32regs = FADD_rnf32ri %5, float 0x7FF8000000000000
%6 = FADD_rnf32ri %5, float 0x7FF8000000000000
%7 = FMUL_rnf32rr %6, %4
StoreRetvalF32 %7, 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
---
name: test
registers:
- { id: 0, class: float32regs }
- { id: 1, class: float32regs }
- { id: 0, class: int32regs }
- { id: 1, class: int32regs }
body: |
bb.0.entry:
%0 = LD_f32 0, 4, 1, 2, 32, &test_param_0, 0
Expand Down
106 changes: 53 additions & 53 deletions llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
Original file line number Diff line number Diff line change
Expand Up @@ -45,36 +45,36 @@ define half @fh(ptr %p) {
; ENABLED-LABEL: fh(
; ENABLED: {
; ENABLED-NEXT: .reg .b16 %rs<10>;
; ENABLED-NEXT: .reg .b32 %f<13>;
; ENABLED-NEXT: .reg .b32 %r<13>;
; ENABLED-NEXT: .reg .b64 %rd<2>;
; ENABLED-EMPTY:
; ENABLED-NEXT: // %bb.0:
; ENABLED-NEXT: ld.param.b64 %rd1, [fh_param_0];
; ENABLED-NEXT: ld.v4.b16 {%rs1, %rs2, %rs3, %rs4}, [%rd1];
; ENABLED-NEXT: ld.b16 %rs5, [%rd1+8];
; ENABLED-NEXT: cvt.f32.f16 %f1, %rs2;
; ENABLED-NEXT: cvt.f32.f16 %f2, %rs1;
; ENABLED-NEXT: add.rn.f32 %f3, %f2, %f1;
; ENABLED-NEXT: cvt.rn.f16.f32 %rs6, %f3;
; ENABLED-NEXT: cvt.f32.f16 %f4, %rs4;
; ENABLED-NEXT: cvt.f32.f16 %f5, %rs3;
; ENABLED-NEXT: add.rn.f32 %f6, %f5, %f4;
; ENABLED-NEXT: cvt.rn.f16.f32 %rs7, %f6;
; ENABLED-NEXT: cvt.f32.f16 %f7, %rs7;
; ENABLED-NEXT: cvt.f32.f16 %f8, %rs6;
; ENABLED-NEXT: add.rn.f32 %f9, %f8, %f7;
; ENABLED-NEXT: cvt.rn.f16.f32 %rs8, %f9;
; ENABLED-NEXT: cvt.f32.f16 %f10, %rs8;
; ENABLED-NEXT: cvt.f32.f16 %f11, %rs5;
; ENABLED-NEXT: add.rn.f32 %f12, %f10, %f11;
; ENABLED-NEXT: cvt.rn.f16.f32 %rs9, %f12;
; ENABLED-NEXT: cvt.f32.f16 %r1, %rs2;
; ENABLED-NEXT: cvt.f32.f16 %r2, %rs1;
; ENABLED-NEXT: add.rn.f32 %r3, %r2, %r1;
; ENABLED-NEXT: cvt.rn.f16.f32 %rs6, %r3;
; ENABLED-NEXT: cvt.f32.f16 %r4, %rs4;
; ENABLED-NEXT: cvt.f32.f16 %r5, %rs3;
; ENABLED-NEXT: add.rn.f32 %r6, %r5, %r4;
; ENABLED-NEXT: cvt.rn.f16.f32 %rs7, %r6;
; ENABLED-NEXT: cvt.f32.f16 %r7, %rs7;
; ENABLED-NEXT: cvt.f32.f16 %r8, %rs6;
; ENABLED-NEXT: add.rn.f32 %r9, %r8, %r7;
; ENABLED-NEXT: cvt.rn.f16.f32 %rs8, %r9;
; ENABLED-NEXT: cvt.f32.f16 %r10, %rs8;
; ENABLED-NEXT: cvt.f32.f16 %r11, %rs5;
; ENABLED-NEXT: add.rn.f32 %r12, %r10, %r11;
; ENABLED-NEXT: cvt.rn.f16.f32 %rs9, %r12;
; ENABLED-NEXT: st.param.b16 [func_retval0], %rs9;
; ENABLED-NEXT: ret;
;
; DISABLED-LABEL: fh(
; DISABLED: {
; DISABLED-NEXT: .reg .b16 %rs<10>;
; DISABLED-NEXT: .reg .b32 %f<13>;
; DISABLED-NEXT: .reg .b32 %r<13>;
; DISABLED-NEXT: .reg .b64 %rd<2>;
; DISABLED-EMPTY:
; DISABLED-NEXT: // %bb.0:
Expand All @@ -84,22 +84,22 @@ define half @fh(ptr %p) {
; DISABLED-NEXT: ld.b16 %rs3, [%rd1+4];
; DISABLED-NEXT: ld.b16 %rs4, [%rd1+6];
; DISABLED-NEXT: ld.b16 %rs5, [%rd1+8];
; DISABLED-NEXT: cvt.f32.f16 %f1, %rs2;
; DISABLED-NEXT: cvt.f32.f16 %f2, %rs1;
; DISABLED-NEXT: add.rn.f32 %f3, %f2, %f1;
; DISABLED-NEXT: cvt.rn.f16.f32 %rs6, %f3;
; DISABLED-NEXT: cvt.f32.f16 %f4, %rs4;
; DISABLED-NEXT: cvt.f32.f16 %f5, %rs3;
; DISABLED-NEXT: add.rn.f32 %f6, %f5, %f4;
; DISABLED-NEXT: cvt.rn.f16.f32 %rs7, %f6;
; DISABLED-NEXT: cvt.f32.f16 %f7, %rs7;
; DISABLED-NEXT: cvt.f32.f16 %f8, %rs6;
; DISABLED-NEXT: add.rn.f32 %f9, %f8, %f7;
; DISABLED-NEXT: cvt.rn.f16.f32 %rs8, %f9;
; DISABLED-NEXT: cvt.f32.f16 %f10, %rs8;
; DISABLED-NEXT: cvt.f32.f16 %f11, %rs5;
; DISABLED-NEXT: add.rn.f32 %f12, %f10, %f11;
; DISABLED-NEXT: cvt.rn.f16.f32 %rs9, %f12;
; DISABLED-NEXT: cvt.f32.f16 %r1, %rs2;
; DISABLED-NEXT: cvt.f32.f16 %r2, %rs1;
; DISABLED-NEXT: add.rn.f32 %r3, %r2, %r1;
; DISABLED-NEXT: cvt.rn.f16.f32 %rs6, %r3;
; DISABLED-NEXT: cvt.f32.f16 %r4, %rs4;
; DISABLED-NEXT: cvt.f32.f16 %r5, %rs3;
; DISABLED-NEXT: add.rn.f32 %r6, %r5, %r4;
; DISABLED-NEXT: cvt.rn.f16.f32 %rs7, %r6;
; DISABLED-NEXT: cvt.f32.f16 %r7, %rs7;
; DISABLED-NEXT: cvt.f32.f16 %r8, %rs6;
; DISABLED-NEXT: add.rn.f32 %r9, %r8, %r7;
; DISABLED-NEXT: cvt.rn.f16.f32 %rs8, %r9;
; DISABLED-NEXT: cvt.f32.f16 %r10, %rs8;
; DISABLED-NEXT: cvt.f32.f16 %r11, %rs5;
; DISABLED-NEXT: add.rn.f32 %r12, %r10, %r11;
; DISABLED-NEXT: cvt.rn.f16.f32 %rs9, %r12;
; DISABLED-NEXT: st.param.b16 [func_retval0], %rs9;
; DISABLED-NEXT: ret;
%p.1 = getelementptr half, ptr %p, i32 1
Expand All @@ -121,37 +121,37 @@ define half @fh(ptr %p) {
define float @ff(ptr %p) {
; ENABLED-LABEL: ff(
; ENABLED: {
; ENABLED-NEXT: .reg .b32 %f<10>;
; ENABLED-NEXT: .reg .b32 %r<10>;
; ENABLED-NEXT: .reg .b64 %rd<2>;
; ENABLED-EMPTY:
; ENABLED-NEXT: // %bb.0:
; ENABLED-NEXT: ld.param.b64 %rd1, [ff_param_0];
; ENABLED-NEXT: ld.v4.b32 {%f1, %f2, %f3, %f4}, [%rd1];
; ENABLED-NEXT: ld.b32 %f5, [%rd1+16];
; ENABLED-NEXT: add.rn.f32 %f6, %f1, %f2;
; ENABLED-NEXT: add.rn.f32 %f7, %f3, %f4;
; ENABLED-NEXT: add.rn.f32 %f8, %f6, %f7;
; ENABLED-NEXT: add.rn.f32 %f9, %f8, %f5;
; ENABLED-NEXT: st.param.b32 [func_retval0], %f9;
; ENABLED-NEXT: ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
; ENABLED-NEXT: ld.b32 %r5, [%rd1+16];
; ENABLED-NEXT: add.rn.f32 %r6, %r1, %r2;
; ENABLED-NEXT: add.rn.f32 %r7, %r3, %r4;
; ENABLED-NEXT: add.rn.f32 %r8, %r6, %r7;
; ENABLED-NEXT: add.rn.f32 %r9, %r8, %r5;
; ENABLED-NEXT: st.param.b32 [func_retval0], %r9;
; ENABLED-NEXT: ret;
;
; DISABLED-LABEL: ff(
; DISABLED: {
; DISABLED-NEXT: .reg .b32 %f<10>;
; DISABLED-NEXT: .reg .b32 %r<10>;
; DISABLED-NEXT: .reg .b64 %rd<2>;
; DISABLED-EMPTY:
; DISABLED-NEXT: // %bb.0:
; DISABLED-NEXT: ld.param.b64 %rd1, [ff_param_0];
; DISABLED-NEXT: ld.b32 %f1, [%rd1];
; DISABLED-NEXT: ld.b32 %f2, [%rd1+4];
; DISABLED-NEXT: ld.b32 %f3, [%rd1+8];
; DISABLED-NEXT: ld.b32 %f4, [%rd1+12];
; DISABLED-NEXT: ld.b32 %f5, [%rd1+16];
; DISABLED-NEXT: add.rn.f32 %f6, %f1, %f2;
; DISABLED-NEXT: add.rn.f32 %f7, %f3, %f4;
; DISABLED-NEXT: add.rn.f32 %f8, %f6, %f7;
; DISABLED-NEXT: add.rn.f32 %f9, %f8, %f5;
; DISABLED-NEXT: st.param.b32 [func_retval0], %f9;
; DISABLED-NEXT: ld.b32 %r1, [%rd1];
; DISABLED-NEXT: ld.b32 %r2, [%rd1+4];
; DISABLED-NEXT: ld.b32 %r3, [%rd1+8];
; DISABLED-NEXT: ld.b32 %r4, [%rd1+12];
; DISABLED-NEXT: ld.b32 %r5, [%rd1+16];
; DISABLED-NEXT: add.rn.f32 %r6, %r1, %r2;
; DISABLED-NEXT: add.rn.f32 %r7, %r3, %r4;
; DISABLED-NEXT: add.rn.f32 %r8, %r6, %r7;
; DISABLED-NEXT: add.rn.f32 %r9, %r8, %r5;
; DISABLED-NEXT: st.param.b32 [func_retval0], %r9;
; DISABLED-NEXT: ret;
%p.1 = getelementptr float, ptr %p, i32 1
%p.2 = getelementptr float, ptr %p, i32 2
Expand Down
Loading
Loading