Skip to content

Commit 6f318d4

Browse files
committed
[NVPTX] Make minimum/maximum work on older GPUs
We want to use newer instructions if we are targeting sufficiently new SM and PTX versions. If we cannot use those newer instructions, let LLVM synthesize the sequence from more fundamental instructions.
1 parent cf79aba commit 6f318d4

File tree

2 files changed

+1366
-128
lines changed

2 files changed

+1366
-128
lines changed

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -268,16 +268,12 @@ multiclass ADD_SUB_INT_CARRY<string OpcStr, SDNode OpNode> {
268268
}
269269
}
270270

271-
// Template for instructions which take three fp64 or fp32 args. The
272-
// instructions are named "<OpcStr>.f<Width>" (e.g. "min.f64").
271+
// Template for minimum/maximum instructions.
273272
//
274273
// Also defines ftz (flush subnormal inputs and results to sign-preserving
275274
// zero) variants for fp32 functions.
276-
//
277-
// This multiclass should be used for nodes that cannot be folded into FMAs.
278-
// For nodes that can be folded into FMAs (i.e. adds and muls), use
279-
// F3_fma_component.
280-
multiclass F3<string OpcStr, SDNode OpNode> {
275+
multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
276+
if !not(NaN) then {
281277
def f64rr :
282278
NVPTXInst<(outs Float64Regs:$dst),
283279
(ins Float64Regs:$a, Float64Regs:$b),
@@ -288,6 +284,7 @@ multiclass F3<string OpcStr, SDNode OpNode> {
288284
(ins Float64Regs:$a, f64imm:$b),
289285
!strconcat(OpcStr, ".f64 \t$dst, $a, $b;"),
290286
[(set Float64Regs:$dst, (OpNode Float64Regs:$a, fpimm:$b))]>;
287+
}
291288
def f32rr_ftz :
292289
NVPTXInst<(outs Float32Regs:$dst),
293290
(ins Float32Regs:$a, Float32Regs:$b),
@@ -322,45 +319,45 @@ multiclass F3<string OpcStr, SDNode OpNode> {
322319
(ins Int16Regs:$a, Int16Regs:$b),
323320
!strconcat(OpcStr, ".f16 \t$dst, $a, $b;"),
324321
[(set Int16Regs:$dst, (OpNode (f16 Int16Regs:$a), (f16 Int16Regs:$b)))]>,
325-
Requires<[useFP16Math]>;
322+
Requires<[useFP16Math, hasSM<80>, hasPTX<70>]>;
326323

327324
def f16x2rr_ftz :
328325
NVPTXInst<(outs Int32Regs:$dst),
329326
(ins Int32Regs:$a, Int32Regs:$b),
330327
!strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"),
331328
[(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
332-
Requires<[useFP16Math, doF32FTZ]>;
329+
Requires<[useFP16Math, hasSM<80>, hasPTX<70>, doF32FTZ]>;
333330
def f16x2rr :
334331
NVPTXInst<(outs Int32Regs:$dst),
335332
(ins Int32Regs:$a, Int32Regs:$b),
336333
!strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
337334
[(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
338-
Requires<[useFP16Math]>;
335+
Requires<[useFP16Math, hasSM<80>, hasPTX<70>]>;
339336
def bf16rr_ftz :
340337
NVPTXInst<(outs Int16Regs:$dst),
341338
(ins Int16Regs:$a, Int16Regs:$b),
342339
!strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
343340
[(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
344-
Requires<[hasBF16Math, doF32FTZ]>;
341+
Requires<[hasBF16Math, doF32FTZ, hasSM<80>, hasPTX<70>]>;
345342
def bf16rr :
346343
NVPTXInst<(outs Int16Regs:$dst),
347344
(ins Int16Regs:$a, Int16Regs:$b),
348345
!strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
349346
[(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
350-
Requires<[hasBF16Math]>;
347+
Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
351348

352349
def bf16x2rr_ftz :
353350
NVPTXInst<(outs Int32Regs:$dst),
354351
(ins Int32Regs:$a, Int32Regs:$b),
355352
!strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
356353
[(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
357-
Requires<[hasBF16Math, doF32FTZ]>;
354+
Requires<[hasBF16Math, hasSM<80>, hasPTX<70>, doF32FTZ]>;
358355
def bf16x2rr :
359356
NVPTXInst<(outs Int32Regs:$dst),
360357
(ins Int32Regs:$a, Int32Regs:$b),
361358
!strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
362359
[(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
363-
Requires<[hasBF16Math]>;
360+
Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
364361
}
365362

366363
// Template for instructions which take three FP args. The
@@ -1178,11 +1175,10 @@ defm FADD : F3_fma_component<"add", fadd>;
11781175
defm FSUB : F3_fma_component<"sub", fsub>;
11791176
defm FMUL : F3_fma_component<"mul", fmul>;
11801177

1181-
defm FMIN : F3<"min", fminnum>;
1182-
defm FMAX : F3<"max", fmaxnum>;
1183-
// Note: min.NaN.f64 and max.NaN.f64 do not actually exist.
1184-
defm FMINNAN : F3<"min.NaN", fminimum>;
1185-
defm FMAXNAN : F3<"max.NaN", fmaximum>;
1178+
defm FMIN : FMINIMUMMAXIMUM<"min", /* NaN */ false, fminnum>;
1179+
defm FMAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
1180+
defm FMINNAN : FMINIMUMMAXIMUM<"min.NaN", /* NaN */ true, fminimum>;
1181+
defm FMAXNAN : FMINIMUMMAXIMUM<"max.NaN", /* NaN */ true, fmaximum>;
11861182

11871183
defm FABS : F2<"abs", fabs>;
11881184
defm FNEG : F2<"neg", fneg>;

0 commit comments

Comments
 (0)