Skip to content

Commit 02161c6

Browse files
authored
[NVPTX] Misc table-gen cleanup (NFC) (#142877)
1 parent 34a1b8c commit 02161c6

File tree

3 files changed

+1065
-2475
lines changed

3 files changed

+1065
-2475
lines changed

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 82 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@ def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
160160
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
161161

162162
def True : Predicate<"true">;
163-
def False : Predicate<"false">;
164163

165164
class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
166165
class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
@@ -257,6 +256,11 @@ def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
257256
// "prmt.b32${mode}">;
258257
// ---> "prmt.b32${mode} \t$d, $a, $b, $c;"
259258
//
259+
// * BasicFlagsNVPTXInst<(outs Int64Regs:$state),
260+
// (ins ADDR:$addr),
261+
// "mbarrier.arrive.b64">;
262+
// ---> "mbarrier.arrive.b64 \t$state, [$addr];"
263+
//
260264
class BasicFlagsNVPTXInst<dag outs_dag, dag ins_dag, dag flags_dag, string asmstr,
261265
list<dag> pattern = []>
262266
: NVPTXInst<
@@ -274,7 +278,11 @@ class BasicFlagsNVPTXInst<dag outs_dag, dag ins_dag, dag flags_dag, string asmst
274278
!if(!or(!empty(ins_dag), !empty(outs_dag)), "", ", "),
275279
!interleave(
276280
!foreach(i, !range(!size(ins_dag)),
277-
"$" # !getdagname(ins_dag, i)),
281+
!if(!eq(!cast<string>(!getdagarg<DAGOperand>(ins_dag, i)), "ADDR"),
282+
"[$" # !getdagname(ins_dag, i) # "]",
283+
"$" # !getdagname(ins_dag, i)
284+
)
285+
),
278286
", "))),
279287
";"),
280288
pattern>;
@@ -956,31 +964,17 @@ def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>;
956964
def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>;
957965

958966
// Matchers for signed, unsigned mul.wide ISD nodes.
959-
def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)),
960-
(MULWIDES32 $a, $b)>,
961-
Requires<[doMulWide]>;
962-
def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)),
963-
(MULWIDES32Imm $a, imm:$b)>,
964-
Requires<[doMulWide]>;
965-
def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)),
966-
(MULWIDEU32 $a, $b)>,
967-
Requires<[doMulWide]>;
968-
def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)),
969-
(MULWIDEU32Imm $a, imm:$b)>,
970-
Requires<[doMulWide]>;
967+
let Predicates = [doMulWide] in {
968+
def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>;
969+
def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>;
970+
def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>;
971+
def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)), (MULWIDEU32Imm $a, imm:$b)>;
971972

972-
def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)),
973-
(MULWIDES64 $a, $b)>,
974-
Requires<[doMulWide]>;
975-
def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)),
976-
(MULWIDES64Imm $a, imm:$b)>,
977-
Requires<[doMulWide]>;
978-
def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)),
979-
(MULWIDEU64 $a, $b)>,
980-
Requires<[doMulWide]>;
981-
def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)),
982-
(MULWIDEU64Imm $a, imm:$b)>,
983-
Requires<[doMulWide]>;
973+
def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)), (MULWIDES64 $a, $b)>;
974+
def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)), (MULWIDES64Imm $a, imm:$b)>;
975+
def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)), (MULWIDEU64 $a, $b)>;
976+
def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>;
977+
}
984978

985979
// Predicates used for converting some patterns to mul.wide.
986980
def SInt32Const : PatLeaf<(imm), [{
@@ -1106,18 +1100,12 @@ defm MAD32 : MAD<"mad.lo.s32", i32, Int32Regs, i32imm>;
11061100
defm MAD64 : MAD<"mad.lo.s64", i64, Int64Regs, i64imm>;
11071101
}
11081102

1109-
def INEG16 :
1110-
BasicNVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
1111-
"neg.s16",
1112-
[(set i16:$dst, (ineg i16:$src))]>;
1113-
def INEG32 :
1114-
BasicNVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src),
1115-
"neg.s32",
1116-
[(set i32:$dst, (ineg i32:$src))]>;
1117-
def INEG64 :
1118-
BasicNVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
1119-
"neg.s64",
1120-
[(set i64:$dst, (ineg i64:$src))]>;
1103+
foreach t = [I16RT, I32RT, I64RT] in {
1104+
def NEG_S # t.Size :
1105+
BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
1106+
"neg.s" # t.Size,
1107+
[(set t.Ty:$dst, (ineg t.Ty:$src))]>;
1108+
}
11211109

11221110
//-----------------------------------
11231111
// Floating Point Arithmetic
@@ -1538,7 +1526,7 @@ def bfi : SDNode<"NVPTXISD::BFI", SDTBFI>;
15381526

15391527
def SDTPRMT :
15401528
SDTypeProfile<1, 4, [SDTCisVT<0, i32>, SDTCisVT<1, i32>,
1541-
SDTCisVT<2, i32>, SDTCisVT<3, i32>, SDTCisVT<4, i32>,]>;
1529+
SDTCisVT<2, i32>, SDTCisVT<3, i32>, SDTCisVT<4, i32>]>;
15421530
def prmt : SDNode<"NVPTXISD::PRMT", SDTPRMT>;
15431531

15441532
multiclass BFE<string Instr, ValueType T, RegisterClass RC> {
@@ -1961,15 +1949,15 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
19611949
// f16 -> pred
19621950
def : Pat<(i1 (OpNode f16:$a, f16:$b)),
19631951
(SETP_f16rr $a, $b, ModeFTZ)>,
1964-
Requires<[useFP16Math,doF32FTZ]>;
1952+
Requires<[useFP16Math, doF32FTZ]>;
19651953
def : Pat<(i1 (OpNode f16:$a, f16:$b)),
19661954
(SETP_f16rr $a, $b, Mode)>,
19671955
Requires<[useFP16Math]>;
19681956

19691957
// bf16 -> pred
19701958
def : Pat<(i1 (OpNode bf16:$a, bf16:$b)),
19711959
(SETP_bf16rr $a, $b, ModeFTZ)>,
1972-
Requires<[hasBF16Math,doF32FTZ]>;
1960+
Requires<[hasBF16Math, doF32FTZ]>;
19731961
def : Pat<(i1 (OpNode bf16:$a, bf16:$b)),
19741962
(SETP_bf16rr $a, $b, Mode)>,
19751963
Requires<[hasBF16Math]>;
@@ -2497,24 +2485,20 @@ def : Pat<(f16 (uint_to_fp i32:$a)), (CVT_f16_u32 $a, CvtRN)>;
24972485
def : Pat<(f16 (uint_to_fp i64:$a)), (CVT_f16_u64 $a, CvtRN)>;
24982486

24992487
// sint -> bf16
2500-
def : Pat<(bf16 (sint_to_fp i1:$a)), (CVT_bf16_s32 (SELP_b32ii 1, 0, $a), CvtRN)>,
2501-
Requires<[hasPTX<78>, hasSM<90>]>;
2502-
def : Pat<(bf16 (sint_to_fp i16:$a)), (CVT_bf16_s16 $a, CvtRN)>,
2503-
Requires<[hasPTX<78>, hasSM<90>]>;
2504-
def : Pat<(bf16 (sint_to_fp i32:$a)), (CVT_bf16_s32 $a, CvtRN)>,
2505-
Requires<[hasPTX<78>, hasSM<90>]>;
2506-
def : Pat<(bf16 (sint_to_fp i64:$a)), (CVT_bf16_s64 $a, CvtRN)>,
2507-
Requires<[hasPTX<78>, hasSM<90>]>;
2488+
let Predicates = [hasPTX<78>, hasSM<90>] in {
2489+
def : Pat<(bf16 (sint_to_fp i1:$a)), (CVT_bf16_s32 (SELP_b32ii 1, 0, $a), CvtRN)>;
2490+
def : Pat<(bf16 (sint_to_fp i16:$a)), (CVT_bf16_s16 $a, CvtRN)>;
2491+
def : Pat<(bf16 (sint_to_fp i32:$a)), (CVT_bf16_s32 $a, CvtRN)>;
2492+
def : Pat<(bf16 (sint_to_fp i64:$a)), (CVT_bf16_s64 $a, CvtRN)>;
2493+
}
25082494

25092495
// uint -> bf16
2510-
def : Pat<(bf16 (uint_to_fp i1:$a)), (CVT_bf16_u32 (SELP_b32ii 1, 0, $a), CvtRN)>,
2511-
Requires<[hasPTX<78>, hasSM<90>]>;
2512-
def : Pat<(bf16 (uint_to_fp i16:$a)), (CVT_bf16_u16 $a, CvtRN)>,
2513-
Requires<[hasPTX<78>, hasSM<90>]>;
2514-
def : Pat<(bf16 (uint_to_fp i32:$a)), (CVT_bf16_u32 $a, CvtRN)>,
2515-
Requires<[hasPTX<78>, hasSM<90>]>;
2516-
def : Pat<(bf16 (uint_to_fp i64:$a)), (CVT_bf16_u64 $a, CvtRN)>,
2517-
Requires<[hasPTX<78>, hasSM<90>]>;
2496+
let Predicates = [hasPTX<78>, hasSM<90>] in {
2497+
def : Pat<(bf16 (uint_to_fp i1:$a)), (CVT_bf16_u32 (SELP_b32ii 1, 0, $a), CvtRN)>;
2498+
def : Pat<(bf16 (uint_to_fp i16:$a)), (CVT_bf16_u16 $a, CvtRN)>;
2499+
def : Pat<(bf16 (uint_to_fp i32:$a)), (CVT_bf16_u32 $a, CvtRN)>;
2500+
def : Pat<(bf16 (uint_to_fp i64:$a)), (CVT_bf16_u64 $a, CvtRN)>;
2501+
}
25182502

25192503
// sint -> f32
25202504
def : Pat<(f32 (sint_to_fp i1:$a)), (CVT_f32_s32 (SELP_b32ii -1, 0, $a), CvtRN)>;
@@ -2565,27 +2549,25 @@ def : Pat<(i16 (fp_to_uint bf16:$a)), (CVT_u16_bf16 $a, CvtRZI)>;
25652549
def : Pat<(i32 (fp_to_uint bf16:$a)), (CVT_u32_bf16 $a, CvtRZI)>;
25662550
def : Pat<(i64 (fp_to_uint bf16:$a)), (CVT_u64_bf16 $a, CvtRZI)>;
25672551
// f32 -> sint
2568-
def : Pat<(i1 (fp_to_sint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
2569-
def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI_FTZ)>,
2570-
Requires<[doF32FTZ]>;
2552+
let Predicates = [doF32FTZ] in {
2553+
def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI_FTZ)>;
2554+
def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI_FTZ)>;
2555+
def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI_FTZ)>;
2556+
}
2557+
def : Pat<(i1 (fp_to_sint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
25712558
def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI)>;
2572-
def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI_FTZ)>,
2573-
Requires<[doF32FTZ]>;
25742559
def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI)>;
2575-
def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI_FTZ)>,
2576-
Requires<[doF32FTZ]>;
25772560
def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI)>;
25782561

25792562
// f32 -> uint
2563+
let Predicates = [doF32FTZ] in {
2564+
def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI_FTZ)>;
2565+
def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI_FTZ)>;
2566+
def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI_FTZ)>;
2567+
}
25802568
def : Pat<(i1 (fp_to_uint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
2581-
def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI_FTZ)>,
2582-
Requires<[doF32FTZ]>;
25832569
def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI)>;
2584-
def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI_FTZ)>,
2585-
Requires<[doF32FTZ]>;
25862570
def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI)>;
2587-
def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI_FTZ)>,
2588-
Requires<[doF32FTZ]>;
25892571
def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI)>;
25902572

25912573
// f64 -> sint
@@ -2707,28 +2689,24 @@ let hasSideEffects = false in {
27072689

27082690
// PTX 7.1 lets you avoid a temp register and just use _ as a "sink" for the
27092691
// unused high/low part.
2710-
def I32toI16H_Sink : NVPTXInst<(outs Int16Regs:$high),
2711-
(ins Int32Regs:$s),
2712-
"mov.b32 \t{{_, $high}}, $s;",
2713-
[]>, Requires<[hasPTX<71>]>;
2714-
def I32toI16L_Sink : NVPTXInst<(outs Int16Regs:$low),
2715-
(ins Int32Regs:$s),
2716-
"mov.b32 \t{{$low, _}}, $s;",
2717-
[]>, Requires<[hasPTX<71>]>;
2718-
def I64toI32H_Sink : NVPTXInst<(outs Int32Regs:$high),
2719-
(ins Int64Regs:$s),
2720-
"mov.b64 \t{{_, $high}}, $s;",
2721-
[]>, Requires<[hasPTX<71>]>;
2722-
def I64toI32L_Sink : NVPTXInst<(outs Int32Regs:$low),
2723-
(ins Int64Regs:$s),
2724-
"mov.b64 \t{{$low, _}}, $s;",
2725-
[]>, Requires<[hasPTX<71>]>;
2692+
let Predicates = [hasPTX<71>] in {
2693+
def I32toI16H_Sink : NVPTXInst<(outs Int16Regs:$high), (ins Int32Regs:$s),
2694+
"mov.b32 \t{{_, $high}}, $s;", []>;
2695+
def I32toI16L_Sink : NVPTXInst<(outs Int16Regs:$low), (ins Int32Regs:$s),
2696+
"mov.b32 \t{{$low, _}}, $s;", []>;
2697+
def I64toI32H_Sink : NVPTXInst<(outs Int32Regs:$high), (ins Int64Regs:$s),
2698+
"mov.b64 \t{{_, $high}}, $s;", []>;
2699+
def I64toI32L_Sink : NVPTXInst<(outs Int32Regs:$low), (ins Int64Regs:$s),
2700+
"mov.b64 \t{{$low, _}}, $s;", []>;
2701+
}
27262702
}
27272703

2728-
def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>, Requires<[hasPTX<71>]>;
2729-
def : Pat<(i16 (trunc (sra i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>, Requires<[hasPTX<71>]>;
2730-
def : Pat<(i32 (trunc (srl i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>, Requires<[hasPTX<71>]>;
2731-
def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>, Requires<[hasPTX<71>]>;
2704+
let Predicates = [hasPTX<71>] in {
2705+
def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>;
2706+
def : Pat<(i16 (trunc (sra i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>;
2707+
def : Pat<(i32 (trunc (srl i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>;
2708+
def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>;
2709+
}
27322710

27332711
// Fall back to the old way if we don't have PTX 7.1.
27342712
def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H $s)>;
@@ -3061,29 +3039,19 @@ def stacksave :
30613039
SDNode<"NVPTXISD::STACKSAVE", SDTIntLeaf,
30623040
[SDNPHasChain, SDNPSideEffect]>;
30633041

3064-
def STACKRESTORE_32 :
3065-
BasicNVPTXInst<(outs), (ins Int32Regs:$ptr),
3066-
"stackrestore.u32",
3067-
[(stackrestore i32:$ptr)]>,
3068-
Requires<[hasPTX<73>, hasSM<52>]>;
3069-
3070-
def STACKSAVE_32 :
3071-
BasicNVPTXInst<(outs Int32Regs:$dst), (ins),
3072-
"stacksave.u32",
3073-
[(set i32:$dst, (i32 stacksave))]>,
3074-
Requires<[hasPTX<73>, hasSM<52>]>;
3075-
3076-
def STACKRESTORE_64 :
3077-
BasicNVPTXInst<(outs), (ins Int64Regs:$ptr),
3078-
"stackrestore.u64",
3079-
[(stackrestore i64:$ptr)]>,
3080-
Requires<[hasPTX<73>, hasSM<52>]>;
3081-
3082-
def STACKSAVE_64 :
3083-
BasicNVPTXInst<(outs Int64Regs:$dst), (ins),
3084-
"stacksave.u64",
3085-
[(set i64:$dst, (i64 stacksave))]>,
3086-
Requires<[hasPTX<73>, hasSM<52>]>;
3042+
let Predicates = [hasPTX<73>, hasSM<52>] in {
3043+
foreach t = [I32RT, I64RT] in {
3044+
def STACKRESTORE_ # t.Size :
3045+
BasicNVPTXInst<(outs), (ins t.RC:$ptr),
3046+
"stackrestore.u" # t.Size,
3047+
[(stackrestore t.Ty:$ptr)]>;
3048+
3049+
def STACKSAVE_ # t.Size :
3050+
BasicNVPTXInst<(outs t.RC:$dst), (ins),
3051+
"stacksave.u" # t.Size,
3052+
[(set t.Ty:$dst, (t.Ty stacksave))]>;
3053+
}
3054+
}
30873055

30883056
include "NVPTXIntrinsics.td"
30893057

@@ -3124,7 +3092,7 @@ def : Pat <
31243092
////////////////////////////////////////////////////////////////////////////////
31253093

31263094
class NVPTXFenceInst<string scope, string sem, Predicate ptx>:
3127-
NVPTXInst<(outs), (ins), "fence."#sem#"."#scope#";", []>,
3095+
BasicNVPTXInst<(outs), (ins), "fence."#sem#"."#scope>,
31283096
Requires<[ptx, hasSM<70>]>;
31293097

31303098
foreach scope = ["sys", "gpu", "cluster", "cta"] in {

0 commit comments

Comments
 (0)