Skip to content

[NVPTX] Cleanup ISel for selp.* #135065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 44 additions & 66 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -725,62 +725,40 @@ def : Pat<(v2f16 (build_vector (f16 (fpround_oneuse f32:$lo)),
// selp instructions that don't have any pattern matches; we explicitly use
// them within this file.
let hasSideEffects = false in {
multiclass SELP<string TypeStr, RegisterClass RC, Operand ImmCls> {
def rr : NVPTXInst<(outs RC:$dst),
(ins RC:$a, RC:$b, Int1Regs:$p),
!strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
def ri : NVPTXInst<(outs RC:$dst),
(ins RC:$a, ImmCls:$b, Int1Regs:$p),
!strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
def ir : NVPTXInst<(outs RC:$dst),
(ins ImmCls:$a, RC:$b, Int1Regs:$p),
!strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
def ii : NVPTXInst<(outs RC:$dst),
(ins ImmCls:$a, ImmCls:$b, Int1Regs:$p),
!strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
}

multiclass SELP_PATTERN<string TypeStr, ValueType T, RegisterClass RC,
Operand ImmCls, SDNode ImmNode> {
multiclass SELP_PATTERN<string TypeStr, RegTyInfo t> {
defvar asm_str = "selp." # TypeStr # " \t$dst, $a, $b, $p;";
def rr :
NVPTXInst<(outs RC:$dst),
(ins RC:$a, RC:$b, Int1Regs:$p),
!strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
[(set T:$dst, (select i1:$p, T:$a, T:$b))]>;
NVPTXInst<(outs t.RC:$dst),
(ins t.RC:$a, t.RC:$b, Int1Regs:$p),
asm_str,
[(set t.Ty:$dst, (select i1:$p, t.Ty:$a, t.Ty:$b))]>;
def ri :
NVPTXInst<(outs RC:$dst),
(ins RC:$a, ImmCls:$b, Int1Regs:$p),
!strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
[(set T:$dst, (select i1:$p, T:$a, (T ImmNode:$b)))]>;
NVPTXInst<(outs t.RC:$dst),
(ins t.RC:$a, t.Imm:$b, Int1Regs:$p),
asm_str,
[(set t.Ty:$dst, (select i1:$p, t.Ty:$a, t.ImmNode:$b))]>;
def ir :
NVPTXInst<(outs RC:$dst),
(ins ImmCls:$a, RC:$b, Int1Regs:$p),
!strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
[(set T:$dst, (select i1:$p, ImmNode:$a, T:$b))]>;
NVPTXInst<(outs t.RC:$dst),
(ins t.Imm:$a, t.RC:$b, Int1Regs:$p),
asm_str,
[(set t.Ty:$dst, (select i1:$p, t.ImmNode:$a, t.Ty:$b))]>;
def ii :
NVPTXInst<(outs RC:$dst),
(ins ImmCls:$a, ImmCls:$b, Int1Regs:$p),
!strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
[(set T:$dst, (select i1:$p, ImmNode:$a, ImmNode:$b))]>;
NVPTXInst<(outs t.RC:$dst),
(ins t.Imm:$a, t.Imm:$b, Int1Regs:$p),
asm_str,
[(set t.Ty:$dst, (select i1:$p, t.ImmNode:$a, t.ImmNode:$b))]>;
}
}

// Don't pattern match on selp.{s,u}{16,32,64} -- selp.b{16,32,64} is just as
// good.
defm SELP_b16 : SELP_PATTERN<"b16", i16, Int16Regs, i16imm, imm>;
defm SELP_s16 : SELP<"s16", Int16Regs, i16imm>;
defm SELP_u16 : SELP<"u16", Int16Regs, i16imm>;
defm SELP_b32 : SELP_PATTERN<"b32", i32, Int32Regs, i32imm, imm>;
defm SELP_s32 : SELP<"s32", Int32Regs, i32imm>;
defm SELP_u32 : SELP<"u32", Int32Regs, i32imm>;
defm SELP_b64 : SELP_PATTERN<"b64", i64, Int64Regs, i64imm, imm>;
defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>;
defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>;
defm SELP_f16 : SELP_PATTERN<"b16", f16, Int16Regs, f16imm, fpimm>;
defm SELP_bf16 : SELP_PATTERN<"b16", bf16, Int16Regs, bf16imm, fpimm>;

defm SELP_f32 : SELP_PATTERN<"f32", f32, Float32Regs, f32imm, fpimm>;
defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>;
defm SELP_b16 : SELP_PATTERN<"b16", I16RT>;
defm SELP_b32 : SELP_PATTERN<"b32", I32RT>;
defm SELP_b64 : SELP_PATTERN<"b64", I64RT>;
defm SELP_f16 : SELP_PATTERN<"b16", F16RT>;
defm SELP_bf16 : SELP_PATTERN<"b16", BF16RT>;
defm SELP_f32 : SELP_PATTERN<"f32", F32RT>;
defm SELP_f64 : SELP_PATTERN<"f64", F64RT>;

// This does not work as tablegen fails to infer the type of 'imm'.
// def v2f16imm : Operand<v2f16>;
Expand Down Expand Up @@ -2023,9 +2001,9 @@ def: Pat<(setne (i16 (and (trunc (bfe Int32Regs:$a, imm:$oa, 8)), 255)),

// i1 compare -> i32
// Each result is a select between -1 and 0, predicated on the XOR of the two
// i1 operands.
// NOTE(review): both patterns below match `setne` on the same operands but
// pass the select values in opposite order (-1, 0 vs. 0, -1). The second one
// looks like it was meant to match `seteq` -- confirm.
def : Pat<(i32 (setne i1:$a, i1:$b)),
(SELP_u32ii -1, 0, (XORb1rr $a, $b))>;
(SELP_b32ii -1, 0, (XORb1rr $a, $b))>;
def : Pat<(i32 (setne i1:$a, i1:$b)),
(SELP_u32ii 0, -1, (XORb1rr $a, $b))>;
(SELP_b32ii 0, -1, (XORb1rr $a, $b))>;



Expand Down Expand Up @@ -2690,7 +2668,7 @@ foreach ta = [v2f16, v2bf16, v2i16, v4i8, i32] in {

// sint -> f16
def : Pat<(f16 (sint_to_fp i1:$a)),
(CVT_f16_s32 (SELP_s32ii -1, 0, $a), CvtRN)>;
(CVT_f16_s32 (SELP_b32ii -1, 0, $a), CvtRN)>;
def : Pat<(f16 (sint_to_fp Int16Regs:$a)),
(CVT_f16_s16 $a, CvtRN)>;
def : Pat<(f16 (sint_to_fp i32:$a)),
Expand All @@ -2700,7 +2678,7 @@ def : Pat<(f16 (sint_to_fp i64:$a)),

// uint -> f16
def : Pat<(f16 (uint_to_fp i1:$a)),
(CVT_f16_u32 (SELP_u32ii 1, 0, $a), CvtRN)>;
(CVT_f16_u32 (SELP_b32ii 1, 0, $a), CvtRN)>;
def : Pat<(f16 (uint_to_fp Int16Regs:$a)),
(CVT_f16_u16 $a, CvtRN)>;
def : Pat<(f16 (uint_to_fp i32:$a)),
Expand All @@ -2710,7 +2688,7 @@ def : Pat<(f16 (uint_to_fp i64:$a)),

// sint -> bf16
// NOTE(review): the other `sint_to_fp i1` patterns in this file (f16, f32,
// f64) select -1 when the predicate is true; this one selects 1, the same
// value the uint patterns use. Confirm whether -1 was intended here as well.
def : Pat<(bf16 (sint_to_fp i1:$a)),
(CVT_bf16_s32 (SELP_u32ii 1, 0, $a), CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
(CVT_bf16_s32 (SELP_b32ii 1, 0, $a), CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (sint_to_fp i16:$a)),
(CVT_bf16_s16 $a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (sint_to_fp i32:$a)),
Expand All @@ -2720,7 +2698,7 @@ def : Pat<(bf16 (sint_to_fp i64:$a)),

// uint -> bf16
def : Pat<(bf16 (uint_to_fp i1:$a)),
(CVT_bf16_u32 (SELP_u32ii 1, 0, $a), CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
(CVT_bf16_u32 (SELP_b32ii 1, 0, $a), CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (uint_to_fp i16:$a)),
(CVT_bf16_u16 $a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
def : Pat<(bf16 (uint_to_fp i32:$a)),
Expand All @@ -2730,7 +2708,7 @@ def : Pat<(bf16 (uint_to_fp i64:$a)),

// sint -> f32
def : Pat<(f32 (sint_to_fp i1:$a)),
(CVT_f32_s32 (SELP_s32ii -1, 0, $a), CvtRN)>;
(CVT_f32_s32 (SELP_b32ii -1, 0, $a), CvtRN)>;
def : Pat<(f32 (sint_to_fp i16:$a)),
(CVT_f32_s16 $a, CvtRN)>;
def : Pat<(f32 (sint_to_fp i32:$a)),
Expand All @@ -2740,7 +2718,7 @@ def : Pat<(f32 (sint_to_fp i64:$a)),

// uint -> f32
def : Pat<(f32 (uint_to_fp i1:$a)),
(CVT_f32_u32 (SELP_u32ii 1, 0, $a), CvtRN)>;
(CVT_f32_u32 (SELP_b32ii 1, 0, $a), CvtRN)>;
def : Pat<(f32 (uint_to_fp i16:$a)),
(CVT_f32_u16 $a, CvtRN)>;
def : Pat<(f32 (uint_to_fp i32:$a)),
Expand All @@ -2750,7 +2728,7 @@ def : Pat<(f32 (uint_to_fp i64:$a)),

// sint -> f64
def : Pat<(f64 (sint_to_fp i1:$a)),
(CVT_f64_s32 (SELP_s32ii -1, 0, $a), CvtRN)>;
(CVT_f64_s32 (SELP_b32ii -1, 0, $a), CvtRN)>;
def : Pat<(f64 (sint_to_fp i16:$a)),
(CVT_f64_s16 $a, CvtRN)>;
def : Pat<(f64 (sint_to_fp i32:$a)),
Expand All @@ -2760,7 +2738,7 @@ def : Pat<(f64 (sint_to_fp i64:$a)),

// uint -> f64
def : Pat<(f64 (uint_to_fp i1:$a)),
(CVT_f64_u32 (SELP_u32ii 1, 0, $a), CvtRN)>;
(CVT_f64_u32 (SELP_b32ii 1, 0, $a), CvtRN)>;
def : Pat<(f64 (uint_to_fp i16:$a)),
(CVT_f64_u16 $a, CvtRN)>;
def : Pat<(f64 (uint_to_fp i32:$a)),
Expand Down Expand Up @@ -2862,27 +2840,27 @@ def : Pat<(i64 (fp_to_uint f64:$a)),

// sext i1
// Lowered to a select on the i1 predicate: -1 when $a is true, 0 otherwise,
// at each result width (16, 32, 64 bits).
def : Pat<(i16 (sext i1:$a)),
(SELP_s16ii -1, 0, $a)>;
(SELP_b16ii -1, 0, $a)>;
def : Pat<(i32 (sext i1:$a)),
(SELP_s32ii -1, 0, $a)>;
(SELP_b32ii -1, 0, $a)>;
def : Pat<(i64 (sext i1:$a)),
(SELP_s64ii -1, 0, $a)>;
(SELP_b64ii -1, 0, $a)>;

// zext i1
// Lowered to a select on the i1 predicate: 1 when $a is true, 0 otherwise,
// at each result width (16, 32, 64 bits).
def : Pat<(i16 (zext i1:$a)),
(SELP_u16ii 1, 0, $a)>;
(SELP_b16ii 1, 0, $a)>;
def : Pat<(i32 (zext i1:$a)),
(SELP_u32ii 1, 0, $a)>;
(SELP_b32ii 1, 0, $a)>;
def : Pat<(i64 (zext i1:$a)),
(SELP_u64ii 1, 0, $a)>;
(SELP_b64ii 1, 0, $a)>;

// anyext i1
// Uses the same -1/0 select values as the sext i1 patterns in this file.
def : Pat<(i16 (anyext i1:$a)),
(SELP_u16ii -1, 0, $a)>;
(SELP_b16ii -1, 0, $a)>;
def : Pat<(i32 (anyext i1:$a)),
(SELP_u32ii -1, 0, $a)>;
(SELP_b32ii -1, 0, $a)>;
def : Pat<(i64 (anyext i1:$a)),
(SELP_u64ii -1, 0, $a)>;
(SELP_b64ii -1, 0, $a)>;

// sext i16
def : Pat<(i32 (sext i16:$a)),
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/NVPTX/add-sub-128bit.ll
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ define i128 @test_add(i128 %a, i128 %b) {
; NOCARRY: add.s64
; NOCARRY-NEXT: add.s64
; NOCARRY-NEXT: setp.lt.u64
; NOCARRY-NEXT: selp.u64
; NOCARRY-NEXT: selp.b64
; NOCARRY-NEXT: add.s64

; CARRY: add.cc.s64
Expand All @@ -23,7 +23,7 @@ define i128 @test_add(i128 %a, i128 %b) {
define i128 @test_sub(i128 %a, i128 %b) {
; NOCARRY: sub.s64
; NOCARRY-NEXT: setp.lt.u64
; NOCARRY-NEXT: selp.s64
; NOCARRY-NEXT: selp.b64
; NOCARRY-NEXT: add.s64
; NOCARRY-NEXT: sub.s64

Expand Down
8 changes: 4 additions & 4 deletions llvm/test/CodeGen/NVPTX/bf16-instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,7 @@ define bfloat @test_uitofp_i1(i1 %a) {
; SM70-NEXT: ld.param.u8 %rs1, [test_uitofp_i1_param_0];
; SM70-NEXT: and.b16 %rs2, %rs1, 1;
; SM70-NEXT: setp.eq.b16 %p1, %rs2, 1;
; SM70-NEXT: selp.u32 %r1, 1, 0, %p1;
; SM70-NEXT: selp.b32 %r1, 1, 0, %p1;
; SM70-NEXT: cvt.rn.f32.u32 %f1, %r1;
; SM70-NEXT: mov.b32 %r2, %f1;
; SM70-NEXT: bfe.u32 %r3, %r2, 16, 1;
Expand All @@ -1148,7 +1148,7 @@ define bfloat @test_uitofp_i1(i1 %a) {
; SM80-NEXT: ld.param.u8 %rs1, [test_uitofp_i1_param_0];
; SM80-NEXT: and.b16 %rs2, %rs1, 1;
; SM80-NEXT: setp.eq.b16 %p1, %rs2, 1;
; SM80-NEXT: selp.u32 %r1, 1, 0, %p1;
; SM80-NEXT: selp.b32 %r1, 1, 0, %p1;
; SM80-NEXT: cvt.rn.f32.u32 %f1, %r1;
; SM80-NEXT: cvt.rn.bf16.f32 %rs3, %f1;
; SM80-NEXT: st.param.b16 [func_retval0], %rs3;
Expand All @@ -1165,7 +1165,7 @@ define bfloat @test_uitofp_i1(i1 %a) {
; SM80-FTZ-NEXT: ld.param.u8 %rs1, [test_uitofp_i1_param_0];
; SM80-FTZ-NEXT: and.b16 %rs2, %rs1, 1;
; SM80-FTZ-NEXT: setp.eq.b16 %p1, %rs2, 1;
; SM80-FTZ-NEXT: selp.u32 %r1, 1, 0, %p1;
; SM80-FTZ-NEXT: selp.b32 %r1, 1, 0, %p1;
; SM80-FTZ-NEXT: cvt.rn.f32.u32 %f1, %r1;
; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs3, %f1;
; SM80-FTZ-NEXT: st.param.b16 [func_retval0], %rs3;
Expand All @@ -1181,7 +1181,7 @@ define bfloat @test_uitofp_i1(i1 %a) {
; SM90-NEXT: ld.param.u8 %rs1, [test_uitofp_i1_param_0];
; SM90-NEXT: and.b16 %rs2, %rs1, 1;
; SM90-NEXT: setp.eq.b16 %p1, %rs2, 1;
; SM90-NEXT: selp.u32 %r1, 1, 0, %p1;
; SM90-NEXT: selp.b32 %r1, 1, 0, %p1;
; SM90-NEXT: cvt.rn.bf16.u32 %rs3, %r1;
; SM90-NEXT: st.param.b16 [func_retval0], %rs3;
; SM90-NEXT: ret;
Expand Down
Loading