@@ -160,7 +160,6 @@ def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
160
160
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
161
161
162
162
def True : Predicate<"true">;
163
- def False : Predicate<"false">;
164
163
165
164
class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
166
165
class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
@@ -257,6 +256,11 @@ def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
257
256
// "prmt.b32${mode}">;
258
257
// ---> "prmt.b32${mode} \t$d, $a, $b, $c;"
259
258
//
259
+ // * BasicFlagsNVPTXInst<(outs Int64Regs:$state),
260
+ // (ins ADDR:$addr), (ins),
261
+ // "mbarrier.arrive.b64">;
262
+ // ---> "mbarrier.arrive.b64 \t$state, [$addr];"
263
+ //
260
264
class BasicFlagsNVPTXInst<dag outs_dag, dag ins_dag, dag flags_dag, string asmstr,
261
265
list<dag> pattern = []>
262
266
: NVPTXInst<
@@ -274,7 +278,11 @@ class BasicFlagsNVPTXInst<dag outs_dag, dag ins_dag, dag flags_dag, string asmst
274
278
!if(!or(!empty(ins_dag), !empty(outs_dag)), "", ", "),
275
279
!interleave(
276
280
!foreach(i, !range(!size(ins_dag)),
277
- "$" # !getdagname(ins_dag, i)),
281
+ !if(!eq(!cast<string>(!getdagarg<DAGOperand>(ins_dag, i)), "ADDR"),
282
+ "[$" # !getdagname(ins_dag, i) # "]",
283
+ "$" # !getdagname(ins_dag, i)
284
+ )
285
+ ),
278
286
", "))),
279
287
";"),
280
288
pattern>;
@@ -956,31 +964,17 @@ def mul_wide_signed : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>;
956
964
def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>;
957
965
958
966
// Matchers for signed, unsigned mul.wide ISD nodes.
959
- def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)),
960
- (MULWIDES32 $a, $b)>,
961
- Requires<[doMulWide]>;
962
- def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)),
963
- (MULWIDES32Imm $a, imm:$b)>,
964
- Requires<[doMulWide]>;
965
- def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)),
966
- (MULWIDEU32 $a, $b)>,
967
- Requires<[doMulWide]>;
968
- def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)),
969
- (MULWIDEU32Imm $a, imm:$b)>,
970
- Requires<[doMulWide]>;
967
// All eight mul.wide selection patterns are gated on doMulWide through this
// single `let` instead of a per-pattern Requires<[doMulWide]>.
// Coverage: i16 x i16 -> i32 and i32 x i32 -> i64, for both mul_wide_signed and
// mul_wide_unsigned, each with a register second operand and an `imm` second
// operand (the *Imm instruction variants).
+ let Predicates = [doMulWide] in {
968
+ def : Pat<(i32 (mul_wide_signed i16:$a, i16:$b)), (MULWIDES32 $a, $b)>;
969
+ def : Pat<(i32 (mul_wide_signed i16:$a, imm:$b)), (MULWIDES32Imm $a, imm:$b)>;
970
+ def : Pat<(i32 (mul_wide_unsigned i16:$a, i16:$b)), (MULWIDEU32 $a, $b)>;
971
+ def : Pat<(i32 (mul_wide_unsigned i16:$a, imm:$b)), (MULWIDEU32Imm $a, imm:$b)>;
971
972
972
- def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)),
973
- (MULWIDES64 $a, $b)>,
974
- Requires<[doMulWide]>;
975
- def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)),
976
- (MULWIDES64Imm $a, imm:$b)>,
977
- Requires<[doMulWide]>;
978
- def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)),
979
- (MULWIDEU64 $a, $b)>,
980
- Requires<[doMulWide]>;
981
- def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)),
982
- (MULWIDEU64Imm $a, imm:$b)>,
983
- Requires<[doMulWide]>;
973
+ def : Pat<(i64 (mul_wide_signed i32:$a, i32:$b)), (MULWIDES64 $a, $b)>;
974
+ def : Pat<(i64 (mul_wide_signed i32:$a, imm:$b)), (MULWIDES64Imm $a, imm:$b)>;
975
+ def : Pat<(i64 (mul_wide_unsigned i32:$a, i32:$b)), (MULWIDEU64 $a, $b)>;
976
+ def : Pat<(i64 (mul_wide_unsigned i32:$a, imm:$b)), (MULWIDEU64Imm $a, imm:$b)>;
977
+ }
984
978
985
979
// Predicates used for converting some patterns to mul.wide.
986
980
def SInt32Const : PatLeaf<(imm), [{
@@ -1106,18 +1100,12 @@ defm MAD32 : MAD<"mad.lo.s32", i32, Int32Regs, i32imm>;
1106
1100
defm MAD64 : MAD<"mad.lo.s64", i64, Int64Regs, i64imm>;
1107
1101
}
1108
1102
1109
- def INEG16 :
1110
- BasicNVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
1111
- "neg.s16",
1112
- [(set i16:$dst, (ineg i16:$src))]>;
1113
- def INEG32 :
1114
- BasicNVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src),
1115
- "neg.s32",
1116
- [(set i32:$dst, (ineg i32:$src))]>;
1117
- def INEG64 :
1118
- BasicNVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
1119
- "neg.s64",
1120
- [(set i64:$dst, (ineg i64:$src))]>;
1103
// Integer negation, defined once per entry of [I16RT, I32RT, I64RT].
// Each iteration creates a record named NEG_S<t.Size> whose mnemonic is
// "neg.s<t.Size>"; t.RC is used for both the source and destination operand
// and t.Ty is the value type in the (ineg ...) selection pattern.
// NOTE(review): these replace the separately written INEG16/INEG32/INEG64
// records, so the record names change — confirm nothing else still refers to
// the INEG* names.
+ foreach t = [I16RT, I32RT, I64RT] in {
1104
+ def NEG_S # t.Size :
1105
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src),
1106
+ "neg.s" # t.Size,
1107
+ [(set t.Ty:$dst, (ineg t.Ty:$src))]>;
1108
+ }
1121
1109
1122
1110
//-----------------------------------
1123
1111
// Floating Point Arithmetic
@@ -1538,7 +1526,7 @@ def bfi : SDNode<"NVPTXISD::BFI", SDTBFI>;
1538
1526
1539
1527
// Type profile for the PRMT node: SDTypeProfile<1, 4, ...> with every index
// 0 through 4 constrained to i32 via SDTCisVT.
// The constraint list is written without a trailing comma before `]`.
def SDTPRMT :
1540
1528
SDTypeProfile<1, 4, [SDTCisVT<0, i32>, SDTCisVT<1, i32>,
1541
- SDTCisVT<2, i32>, SDTCisVT<3, i32>, SDTCisVT<4, i32>, ]>;
1529
+ SDTCisVT<2, i32>, SDTCisVT<3, i32>, SDTCisVT<4, i32>]>;
1542
1530
// Selection DAG node bound to NVPTXISD::PRMT, typed by SDTPRMT above.
def prmt : SDNode<"NVPTXISD::PRMT", SDTPRMT>;
1543
1531
1544
1532
multiclass BFE<string Instr, ValueType T, RegisterClass RC> {
@@ -1961,15 +1949,15 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
1961
1949
// f16 -> pred
1962
1950
def : Pat<(i1 (OpNode f16:$a, f16:$b)),
1963
1951
(SETP_f16rr $a, $b, ModeFTZ)>,
1964
- Requires<[useFP16Math,doF32FTZ]>;
1952
+ Requires<[useFP16Math, doF32FTZ]>;
1965
1953
def : Pat<(i1 (OpNode f16:$a, f16:$b)),
1966
1954
(SETP_f16rr $a, $b, Mode)>,
1967
1955
Requires<[useFP16Math]>;
1968
1956
1969
1957
// bf16 -> pred
1970
1958
def : Pat<(i1 (OpNode bf16:$a, bf16:$b)),
1971
1959
(SETP_bf16rr $a, $b, ModeFTZ)>,
1972
- Requires<[hasBF16Math,doF32FTZ]>;
1960
+ Requires<[hasBF16Math, doF32FTZ]>;
1973
1961
def : Pat<(i1 (OpNode bf16:$a, bf16:$b)),
1974
1962
(SETP_bf16rr $a, $b, Mode)>,
1975
1963
Requires<[hasBF16Math]>;
@@ -2497,24 +2485,20 @@ def : Pat<(f16 (uint_to_fp i32:$a)), (CVT_f16_u32 $a, CvtRN)>;
2497
2485
def : Pat<(f16 (uint_to_fp i64:$a)), (CVT_f16_u64 $a, CvtRN)>;
2498
2486
2499
2487
// sint -> bf16
2500
- def : Pat<(bf16 (sint_to_fp i1:$a)), (CVT_bf16_s32 (SELP_b32ii 1, 0, $a), CvtRN)>,
2501
- Requires<[hasPTX<78>, hasSM<90>]>;
2502
- def : Pat<(bf16 (sint_to_fp i16:$a)), (CVT_bf16_s16 $a, CvtRN)>,
2503
- Requires<[hasPTX<78>, hasSM<90>]>;
2504
- def : Pat<(bf16 (sint_to_fp i32:$a)), (CVT_bf16_s32 $a, CvtRN)>,
2505
- Requires<[hasPTX<78>, hasSM<90>]>;
2506
- def : Pat<(bf16 (sint_to_fp i64:$a)), (CVT_bf16_s64 $a, CvtRN)>,
2507
- Requires<[hasPTX<78>, hasSM<90>]>;
2488
// Signed integer -> bf16 conversions. Every pattern in this group shares the
// [hasPTX<78>, hasSM<90>] predicate list through the enclosing `let`.
// The i1 source has no conversion of its own here: it is first turned into an
// i32 with SELP_b32ii and then converted with CVT_bf16_s32. Because the
// conversion is *signed*, a set i1 has the value -1, so the select has to
// produce -1/0 (matching the f32 sint_to_fp i1 pattern), not 1/0.
+ let Predicates = [hasPTX<78>, hasSM<90>] in {
2489
+ def : Pat<(bf16 (sint_to_fp i1:$a)), (CVT_bf16_s32 (SELP_b32ii -1, 0, $a), CvtRN)>;
2490
+ def : Pat<(bf16 (sint_to_fp i16:$a)), (CVT_bf16_s16 $a, CvtRN)>;
2491
+ def : Pat<(bf16 (sint_to_fp i32:$a)), (CVT_bf16_s32 $a, CvtRN)>;
2492
+ def : Pat<(bf16 (sint_to_fp i64:$a)), (CVT_bf16_s64 $a, CvtRN)>;
2493
+ }
2508
2494
2509
2495
// uint -> bf16
2510
- def : Pat<(bf16 (uint_to_fp i1:$a)), (CVT_bf16_u32 (SELP_b32ii 1, 0, $a), CvtRN)>,
2511
- Requires<[hasPTX<78>, hasSM<90>]>;
2512
- def : Pat<(bf16 (uint_to_fp i16:$a)), (CVT_bf16_u16 $a, CvtRN)>,
2513
- Requires<[hasPTX<78>, hasSM<90>]>;
2514
- def : Pat<(bf16 (uint_to_fp i32:$a)), (CVT_bf16_u32 $a, CvtRN)>,
2515
- Requires<[hasPTX<78>, hasSM<90>]>;
2516
- def : Pat<(bf16 (uint_to_fp i64:$a)), (CVT_bf16_u64 $a, CvtRN)>,
2517
- Requires<[hasPTX<78>, hasSM<90>]>;
2496
// Unsigned integer -> bf16 conversions. Every pattern in this group shares the
// [hasPTX<78>, hasSM<90>] predicate list through the enclosing `let`.
// The i1 source is first turned into an i32 with SELP_b32ii (operands 1, 0, $a;
// presumably 1 when $a is set, else 0) and then converted with CVT_bf16_u32.
// All conversions pass CvtRN as the conversion mode.
+ let Predicates = [hasPTX<78>, hasSM<90>] in {
2497
+ def : Pat<(bf16 (uint_to_fp i1:$a)), (CVT_bf16_u32 (SELP_b32ii 1, 0, $a), CvtRN)>;
2498
+ def : Pat<(bf16 (uint_to_fp i16:$a)), (CVT_bf16_u16 $a, CvtRN)>;
2499
+ def : Pat<(bf16 (uint_to_fp i32:$a)), (CVT_bf16_u32 $a, CvtRN)>;
2500
+ def : Pat<(bf16 (uint_to_fp i64:$a)), (CVT_bf16_u64 $a, CvtRN)>;
2501
+ }
2518
2502
2519
2503
// sint -> f32
2520
2504
def : Pat<(f32 (sint_to_fp i1:$a)), (CVT_f32_s32 (SELP_b32ii -1, 0, $a), CvtRN)>;
@@ -2565,27 +2549,25 @@ def : Pat<(i16 (fp_to_uint bf16:$a)), (CVT_u16_bf16 $a, CvtRZI)>;
2565
2549
def : Pat<(i32 (fp_to_uint bf16:$a)), (CVT_u32_bf16 $a, CvtRZI)>;
2566
2550
def : Pat<(i64 (fp_to_uint bf16:$a)), (CVT_u64_bf16 $a, CvtRZI)>;
2567
2551
// f32 -> sint
2568
- def : Pat<(i1 (fp_to_sint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
2569
- def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI_FTZ)>,
2570
- Requires<[doF32FTZ]>;
2552
// f32 -> signed integer conversions used when doF32FTZ holds: same CVT_s*_f32
// instructions as the unpredicated patterns that follow, but with the
// CvtRZI_FTZ mode instead of CvtRZI.
// NOTE(review): these are listed ahead of the unpredicated patterns for the
// same types; presumably that ordering is what lets the FTZ form win when the
// predicate is true — confirm before reordering.
+ let Predicates = [doF32FTZ] in {
2553
+ def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI_FTZ)>;
2554
+ def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI_FTZ)>;
2555
+ def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI_FTZ)>;
2556
+ }
2557
+ def : Pat<(i1 (fp_to_sint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
2571
2558
def : Pat<(i16 (fp_to_sint f32:$a)), (CVT_s16_f32 $a, CvtRZI)>;
2572
- def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI_FTZ)>,
2573
- Requires<[doF32FTZ]>;
2574
2559
def : Pat<(i32 (fp_to_sint f32:$a)), (CVT_s32_f32 $a, CvtRZI)>;
2575
- def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI_FTZ)>,
2576
- Requires<[doF32FTZ]>;
2577
2560
def : Pat<(i64 (fp_to_sint f32:$a)), (CVT_s64_f32 $a, CvtRZI)>;
2578
2561
2579
2562
// f32 -> uint
2563
// f32 -> unsigned integer conversions used when doF32FTZ holds: same
// CVT_u*_f32 instructions as the unpredicated patterns that follow, but with
// the CvtRZI_FTZ mode instead of CvtRZI.
// NOTE(review): these are listed ahead of the unpredicated patterns for the
// same types; presumably that ordering is what lets the FTZ form win when the
// predicate is true — confirm before reordering.
+ let Predicates = [doF32FTZ] in {
2564
+ def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI_FTZ)>;
2565
+ def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI_FTZ)>;
2566
+ def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI_FTZ)>;
2567
+ }
2580
2568
def : Pat<(i1 (fp_to_uint f32:$a)), (SETP_b32ri $a, 0, CmpEQ)>;
2581
- def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI_FTZ)>,
2582
- Requires<[doF32FTZ]>;
2583
2569
def : Pat<(i16 (fp_to_uint f32:$a)), (CVT_u16_f32 $a, CvtRZI)>;
2584
- def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI_FTZ)>,
2585
- Requires<[doF32FTZ]>;
2586
2570
def : Pat<(i32 (fp_to_uint f32:$a)), (CVT_u32_f32 $a, CvtRZI)>;
2587
- def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI_FTZ)>,
2588
- Requires<[doF32FTZ]>;
2589
2571
def : Pat<(i64 (fp_to_uint f32:$a)), (CVT_u64_f32 $a, CvtRZI)>;
2590
2572
2591
2573
// f64 -> sint
@@ -2707,28 +2689,24 @@ let hasSideEffects = false in {
2707
2689
2708
2690
// PTX 7.1 lets you avoid a temp register and just use _ as a "sink" for the
2709
2691
// unused high/low part.
2710
- def I32toI16H_Sink : NVPTXInst<(outs Int16Regs:$high),
2711
- (ins Int32Regs:$s),
2712
- "mov.b32 \t{{_, $high}}, $s;",
2713
- []>, Requires<[hasPTX<71>]>;
2714
- def I32toI16L_Sink : NVPTXInst<(outs Int16Regs:$low),
2715
- (ins Int32Regs:$s),
2716
- "mov.b32 \t{{$low, _}}, $s;",
2717
- []>, Requires<[hasPTX<71>]>;
2718
- def I64toI32H_Sink : NVPTXInst<(outs Int32Regs:$high),
2719
- (ins Int64Regs:$s),
2720
- "mov.b64 \t{{_, $high}}, $s;",
2721
- []>, Requires<[hasPTX<71>]>;
2722
- def I64toI32L_Sink : NVPTXInst<(outs Int32Regs:$low),
2723
- (ins Int64Regs:$s),
2724
- "mov.b64 \t{{$low, _}}, $s;",
2725
- []>, Requires<[hasPTX<71>]>;
2692
// Half-extraction moves that write the unwanted half to the `_` sink, all
// gated on hasPTX<71> through the enclosing `let`.
//   I32toI16H_Sink / I32toI16L_Sink : mov.b32 from an Int32Regs source into an
//     Int16Regs result placed in the second / first slot of the {{..}} pair.
//   I64toI32H_Sink / I64toI32L_Sink : mov.b64 from an Int64Regs source into an
//     Int32Regs result placed in the second / first slot of the {{..}} pair.
// None of these carries a selection pattern (the pattern list is []).
+ let Predicates = [hasPTX<71>] in {
2693
+ def I32toI16H_Sink : NVPTXInst<(outs Int16Regs:$high), (ins Int32Regs:$s),
2694
+ "mov.b32 \t{{_, $high}}, $s;", []>;
2695
+ def I32toI16L_Sink : NVPTXInst<(outs Int16Regs:$low), (ins Int32Regs:$s),
2696
+ "mov.b32 \t{{$low, _}}, $s;", []>;
2697
+ def I64toI32H_Sink : NVPTXInst<(outs Int32Regs:$high), (ins Int64Regs:$s),
2698
+ "mov.b64 \t{{_, $high}}, $s;", []>;
2699
+ def I64toI32L_Sink : NVPTXInst<(outs Int32Regs:$low), (ins Int64Regs:$s),
2700
+ "mov.b64 \t{{$low, _}}, $s;", []>;
2701
+ }
2726
2702
}
2727
2703
2728
- def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>, Requires<[hasPTX<71>]>;
2729
- def : Pat<(i16 (trunc (sra i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>, Requires<[hasPTX<71>]>;
2730
- def : Pat<(i32 (trunc (srl i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>, Requires<[hasPTX<71>]>;
2731
- def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>, Requires<[hasPTX<71>]>;
2704
// Select the *_Sink high-half moves for "shift right by half the width, then
// truncate", gated on hasPTX<71>.
// Both srl and sra forms map to the same instruction: i32 shifted by 16 and
// truncated to i16 uses I32toI16H_Sink, i64 shifted by 32 and truncated to i32
// uses I64toI32H_Sink.
+ let Predicates = [hasPTX<71>] in {
2705
+ def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>;
2706
+ def : Pat<(i16 (trunc (sra i32:$s, (i32 16)))), (I32toI16H_Sink i32:$s)>;
2707
+ def : Pat<(i32 (trunc (srl i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>;
2708
+ def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H_Sink i64:$s)>;
2709
+ }
2732
2710
2733
2711
// Fall back to the old way if we don't have PTX 7.1.
2734
2712
def : Pat<(i16 (trunc (srl i32:$s, (i32 16)))), (I32toI16H $s)>;
@@ -3061,29 +3039,19 @@ def stacksave :
3061
3039
SDNode<"NVPTXISD::STACKSAVE", SDTIntLeaf,
3062
3040
[SDNPHasChain, SDNPSideEffect]>;
3063
3041
3064
- def STACKRESTORE_32 :
3065
- BasicNVPTXInst<(outs), (ins Int32Regs:$ptr),
3066
- "stackrestore.u32",
3067
- [(stackrestore i32:$ptr)]>,
3068
- Requires<[hasPTX<73>, hasSM<52>]>;
3069
-
3070
- def STACKSAVE_32 :
3071
- BasicNVPTXInst<(outs Int32Regs:$dst), (ins),
3072
- "stacksave.u32",
3073
- [(set i32:$dst, (i32 stacksave))]>,
3074
- Requires<[hasPTX<73>, hasSM<52>]>;
3075
-
3076
- def STACKRESTORE_64 :
3077
- BasicNVPTXInst<(outs), (ins Int64Regs:$ptr),
3078
- "stackrestore.u64",
3079
- [(stackrestore i64:$ptr)]>,
3080
- Requires<[hasPTX<73>, hasSM<52>]>;
3081
-
3082
- def STACKSAVE_64 :
3083
- BasicNVPTXInst<(outs Int64Regs:$dst), (ins),
3084
- "stacksave.u64",
3085
- [(set i64:$dst, (i64 stacksave))]>,
3086
- Requires<[hasPTX<73>, hasSM<52>]>;
3042
// Stack save/restore instructions, defined once per entry of [I32RT, I64RT]
// and all gated on [hasPTX<73>, hasSM<52>] through the enclosing `let`.
//   STACKRESTORE_<t.Size> : no outputs, one t.RC input; mnemonic
//     "stackrestore.u<t.Size>"; selected for (stackrestore t.Ty:$ptr).
//   STACKSAVE_<t.Size>    : one t.RC output, no inputs; mnemonic
//     "stacksave.u<t.Size>"; selected for (t.Ty stacksave).
+ let Predicates = [hasPTX<73>, hasSM<52>] in {
3043
+ foreach t = [I32RT, I64RT] in {
3044
+ def STACKRESTORE_ # t.Size :
3045
+ BasicNVPTXInst<(outs), (ins t.RC:$ptr),
3046
+ "stackrestore.u" # t.Size,
3047
+ [(stackrestore t.Ty:$ptr)]>;
3048
+
3049
+ def STACKSAVE_ # t.Size :
3050
+ BasicNVPTXInst<(outs t.RC:$dst), (ins),
3051
+ "stacksave.u" # t.Size,
3052
+ [(set t.Ty:$dst, (t.Ty stacksave))]>;
3053
+ }
3054
+ }
3087
3055
3088
3056
include "NVPTXIntrinsics.td"
3089
3057
@@ -3124,7 +3092,7 @@ def : Pat <
3124
3092
////////////////////////////////////////////////////////////////////////////////
3125
3093
3126
3094
// Operand-less fence instruction. The mnemonic is built as
// "fence.<sem>.<scope>", and the instruction requires the caller-supplied
// `ptx` predicate together with hasSM<70>.
// The string no longer spells the trailing ";" itself (the NVPTXInst form did).
// NOTE(review): presumably BasicNVPTXInst appends the terminator — confirm
// against its definition.
class NVPTXFenceInst<string scope, string sem, Predicate ptx>:
3127
- NVPTXInst <(outs), (ins), "fence."#sem#"."#scope#";", [] >,
3095
+ BasicNVPTXInst <(outs), (ins), "fence."#sem#"."#scope>,
3128
3096
Requires<[ptx, hasSM<70>]>;
3129
3097
3130
3098
foreach scope = ["sys", "gpu", "cluster", "cta"] in {
0 commit comments