Skip to content

Commit 3ba339b

Browse files
authored
[NVPTX] Improve support for {ex2,lg2}.approx (#120519)
- Add support for `@llvm.exp2()`: - LLVM: `float` -> PTX: `ex2.approx{.ftz}.f32` - LLVM: `half` -> PTX: `ex2.approx.f16` - LLVM: `<2 x half>` -> PTX: `ex2.approx.f16x2` - LLVM: `bfloat` -> PTX: `ex2.approx.ftz.bf16` - LLVM: `<2 x bfloat>` -> PTX: `ex2.approx.ftz.bf16x2` - Any operations with non-native vector widths are expanded. On targets not supporting f16/bf16, values are promoted to f32. - Add *CONDITIONAL* support for `@llvm.log2()` [^1]: - LLVM: `float` -> PTX: `lg2.approx{.ftz}.f32` - Support for f16/bf16 is emulated by promoting values to f32. [1]: CUDA implements `exp2()` with `ex2.approx` but `log2()` is implemented differently, so this is off by default. To enable, use the flag `-nvptx-approx-log2f32`.
1 parent 99d40fe commit 3ba339b

File tree

8 files changed

+808
-13
lines changed

8 files changed

+808
-13
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,13 @@ static cl::opt<bool> UsePrecSqrtF32(
9494
cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
9595
cl::init(true));
9696

97+
/// Whereas CUDA's implementation (see libdevice) uses ex2.approx for exp2(), it
98+
/// does NOT use lg2.approx for log2, so this is disabled by default.
99+
static cl::opt<bool> UseApproxLog2F32(
100+
"nvptx-approx-log2f32",
101+
cl::desc("NVPTX Specific: whether to use lg2.approx for log2"),
102+
cl::init(false));
103+
97104
static cl::opt<bool> ForceMinByValParamAlign(
98105
"nvptx-force-min-byval-param-align", cl::Hidden,
99106
cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval"
@@ -529,6 +536,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
529536
case ISD::FMINIMUM:
530537
IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
531538
break;
539+
case ISD::FEXP2:
540+
IsOpSupported &= STI.getSmVersion() >= 75 && STI.getPTXVersion() >= 70;
541+
break;
532542
}
533543
setOperationAction(Op, VT, IsOpSupported ? Action : NoF16Action);
534544
};
@@ -959,7 +969,26 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
959969
setOperationAction(ISD::CopyToReg, MVT::i128, Custom);
960970
setOperationAction(ISD::CopyFromReg, MVT::i128, Custom);
961971

962-
// No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
972+
// FEXP2 support:
973+
// - f32
974+
// - f16/f16x2 (sm_70+, PTX 7.0+)
975+
// - bf16/bf16x2 (sm_90+, PTX 7.8+)
976+
// When f16/bf16 types aren't supported, they are promoted/expanded to f32.
977+
setOperationAction(ISD::FEXP2, MVT::f32, Legal);
978+
setFP16OperationAction(ISD::FEXP2, MVT::f16, Legal, Promote);
979+
setFP16OperationAction(ISD::FEXP2, MVT::v2f16, Legal, Expand);
980+
setBF16OperationAction(ISD::FEXP2, MVT::bf16, Legal, Promote);
981+
setBF16OperationAction(ISD::FEXP2, MVT::v2bf16, Legal, Expand);
982+
983+
// FLOG2 supports f32 only
984+
// f16/bf16 types aren't supported, but they are promoted/expanded to f32.
985+
if (UseApproxLog2F32) {
986+
setOperationAction(ISD::FLOG2, MVT::f32, Legal);
987+
setOperationPromotedToType(ISD::FLOG2, MVT::f16, MVT::f32);
988+
setOperationPromotedToType(ISD::FLOG2, MVT::bf16, MVT::f32);
989+
setOperationAction(ISD::FLOG2, {MVT::v2f16, MVT::v2bf16}, Expand);
990+
}
991+
963992
// No FPOW or FREM in PTX.
964993

965994
// Now deduce the information based on the above mentioned

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,18 @@ multiclass F2_Support_Half<string OpcStr, SDNode OpNode> {
569569

570570
}
571571

572+
// Variant where only .ftz.bf16 is supported.
573+
multiclass F2_Support_Half_BF<string OpcStr, SDNode OpNode> {
574+
def bf16_ftz : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a),
575+
OpcStr # ".ftz.bf16 \t$dst, $a;",
576+
[(set bf16:$dst, (OpNode bf16:$a))]>,
577+
Requires<[hasSM<90>, hasPTX<78>]>;
578+
def bf16x2_ftz: NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a),
579+
OpcStr # ".ftz.bf16x2 \t$dst, $a;",
580+
[(set v2bf16:$dst, (OpNode v2bf16:$a))]>,
581+
Requires<[hasSM<90>, hasPTX<78>]>;
582+
}
583+
572584
//===----------------------------------------------------------------------===//
573585
// NVPTX Instructions.
574586
//===----------------------------------------------------------------------===//
@@ -1183,6 +1195,8 @@ defm FNEG_H: F2_Support_Half<"neg", fneg>;
11831195

11841196
defm FSQRT : F2<"sqrt.rn", fsqrt>;
11851197

1198+
defm FEXP2_H: F2_Support_Half_BF<"ex2.approx", fexp2>;
1199+
11861200
//
11871201
// F16 NEG
11881202
//

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,18 +1304,33 @@ def INT_NVVM_EX2_APPROX_F : F_MATH_1<"ex2.approx.f32 \t$dst, $src0;",
13041304
Float32Regs, Float32Regs, int_nvvm_ex2_approx_f>;
13051305
def INT_NVVM_EX2_APPROX_D : F_MATH_1<"ex2.approx.f64 \t$dst, $src0;",
13061306
Float64Regs, Float64Regs, int_nvvm_ex2_approx_d>;
1307+
13071308
def INT_NVVM_EX2_APPROX_F16 : F_MATH_1<"ex2.approx.f16 \t$dst, $src0;",
13081309
Int16Regs, Int16Regs, int_nvvm_ex2_approx_f16, [hasPTX<70>, hasSM<75>]>;
13091310
def INT_NVVM_EX2_APPROX_F16X2 : F_MATH_1<"ex2.approx.f16x2 \t$dst, $src0;",
13101311
Int32Regs, Int32Regs, int_nvvm_ex2_approx_f16x2, [hasPTX<70>, hasSM<75>]>;
13111312

1313+
def : Pat<(fexp2 f32:$a),
1314+
(INT_NVVM_EX2_APPROX_FTZ_F $a)>, Requires<[doF32FTZ]>;
1315+
def : Pat<(fexp2 f32:$a),
1316+
(INT_NVVM_EX2_APPROX_F $a)>, Requires<[doNoF32FTZ]>;
1317+
def : Pat<(fexp2 f16:$a),
1318+
(INT_NVVM_EX2_APPROX_F16 $a)>, Requires<[useFP16Math]>;
1319+
def : Pat<(fexp2 v2f16:$a),
1320+
(INT_NVVM_EX2_APPROX_F16X2 $a)>, Requires<[useFP16Math]>;
1321+
13121322
def INT_NVVM_LG2_APPROX_FTZ_F : F_MATH_1<"lg2.approx.ftz.f32 \t$dst, $src0;",
13131323
Float32Regs, Float32Regs, int_nvvm_lg2_approx_ftz_f>;
13141324
def INT_NVVM_LG2_APPROX_F : F_MATH_1<"lg2.approx.f32 \t$dst, $src0;",
13151325
Float32Regs, Float32Regs, int_nvvm_lg2_approx_f>;
13161326
def INT_NVVM_LG2_APPROX_D : F_MATH_1<"lg2.approx.f64 \t$dst, $src0;",
13171327
Float64Regs, Float64Regs, int_nvvm_lg2_approx_d>;
13181328

1329+
def : Pat<(flog2 f32:$a), (INT_NVVM_LG2_APPROX_FTZ_F $a)>,
1330+
Requires<[doF32FTZ]>;
1331+
def : Pat<(flog2 f32:$a), (INT_NVVM_LG2_APPROX_F $a)>,
1332+
Requires<[doNoF32FTZ]>;
1333+
13191334
//
13201335
// Sin Cos
13211336
//

llvm/test/CodeGen/NVPTX/f16-ex2.ll

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,37 @@
1-
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_75 -mattr=+ptx70 | FileCheck %s
2-
; RUN: %if ptxas-11.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_75 -mattr=+ptx70 | %ptxas-verify -arch=sm_75 %}
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mcpu=sm_75 -mattr=+ptx70 | FileCheck --check-prefixes=CHECK-FP16 %s
3+
; RUN: %if ptxas-11.0 %{ llc < %s -mcpu=sm_75 -mattr=+ptx70 | %ptxas-verify -arch=sm_75 %}
4+
target triple = "nvptx64-nvidia-cuda"
35

46
declare half @llvm.nvvm.ex2.approx.f16(half)
57
declare <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half>)
68

7-
; CHECK-LABEL: exp2_half
8-
define half @exp2_half(half %0) {
9-
; CHECK-NOT: call
10-
; CHECK: ex2.approx.f16
11-
%res = call half @llvm.nvvm.ex2.approx.f16(half %0);
9+
; CHECK-LABEL: ex2_half
10+
define half @ex2_half(half %0) {
11+
; CHECK-FP16-LABEL: ex2_half(
12+
; CHECK-FP16: {
13+
; CHECK-FP16-NEXT: .reg .b16 %rs<3>;
14+
; CHECK-FP16-EMPTY:
15+
; CHECK-FP16-NEXT: // %bb.0:
16+
; CHECK-FP16-NEXT: ld.param.b16 %rs1, [ex2_half_param_0];
17+
; CHECK-FP16-NEXT: ex2.approx.f16 %rs2, %rs1;
18+
; CHECK-FP16-NEXT: st.param.b16 [func_retval0], %rs2;
19+
; CHECK-FP16-NEXT: ret;
20+
%res = call half @llvm.nvvm.ex2.approx.f16(half %0)
1221
ret half %res
1322
}
1423

15-
; CHECK-LABEL: exp2_2xhalf
16-
define <2 x half> @exp2_2xhalf(<2 x half> %0) {
17-
; CHECK-NOT: call
18-
; CHECK: ex2.approx.f16x2
19-
%res = call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> %0);
24+
; CHECK-LABEL: ex2_2xhalf
25+
define <2 x half> @ex2_2xhalf(<2 x half> %0) {
26+
; CHECK-FP16-LABEL: ex2_2xhalf(
27+
; CHECK-FP16: {
28+
; CHECK-FP16-NEXT: .reg .b32 %r<3>;
29+
; CHECK-FP16-EMPTY:
30+
; CHECK-FP16-NEXT: // %bb.0:
31+
; CHECK-FP16-NEXT: ld.param.b32 %r1, [ex2_2xhalf_param_0];
32+
; CHECK-FP16-NEXT: ex2.approx.f16x2 %r2, %r1;
33+
; CHECK-FP16-NEXT: st.param.b32 [func_retval0], %r2;
34+
; CHECK-FP16-NEXT: ret;
35+
%res = call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> %0)
2036
ret <2 x half> %res
2137
}

llvm/test/CodeGen/NVPTX/f32-ex2.ll

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mcpu=sm_50 -mattr=+ptx32 | FileCheck --check-prefixes=CHECK %s
3+
; RUN: %if ptxas-11.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_50 -mattr=+ptx32 | %ptxas-verify -arch=sm_50 %}
4+
target triple = "nvptx-nvidia-cuda"
5+
6+
declare float @llvm.nvvm.ex2.approx.f(float)
7+
8+
; CHECK-LABEL: ex2_float
9+
define float @ex2_float(float %0) {
10+
; CHECK-LABEL: ex2_float(
11+
; CHECK: {
12+
; CHECK-NEXT: .reg .f32 %f<3>;
13+
; CHECK-EMPTY:
14+
; CHECK-NEXT: // %bb.0:
15+
; CHECK-NEXT: ld.param.f32 %f1, [ex2_float_param_0];
16+
; CHECK-NEXT: ex2.approx.f32 %f2, %f1;
17+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
18+
; CHECK-NEXT: ret;
19+
%res = call float @llvm.nvvm.ex2.approx.f(float %0)
20+
ret float %res
21+
}
22+
23+
; CHECK-LABEL: ex2_float_ftz
24+
define float @ex2_float_ftz(float %0) {
25+
; CHECK-LABEL: ex2_float_ftz(
26+
; CHECK: {
27+
; CHECK-NEXT: .reg .f32 %f<3>;
28+
; CHECK-EMPTY:
29+
; CHECK-NEXT: // %bb.0:
30+
; CHECK-NEXT: ld.param.f32 %f1, [ex2_float_ftz_param_0];
31+
; CHECK-NEXT: ex2.approx.ftz.f32 %f2, %f1;
32+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
33+
; CHECK-NEXT: ret;
34+
%res = call float @llvm.nvvm.ex2.approx.ftz.f(float %0)
35+
ret float %res
36+
}

llvm/test/CodeGen/NVPTX/f32-lg2.ll

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mcpu=sm_20 -mattr=+ptx32 | FileCheck --check-prefixes=CHECK %s
3+
; RUN: %if ptxas %{ llc < %s -mcpu=sm_20 -mattr=+ptx32 | %ptxas-verify %}
4+
target triple = "nvptx-nvidia-cuda"
5+
6+
declare float @llvm.nvvm.lg2.approx.f(float)
7+
declare float @llvm.nvvm.lg2.approx.ftz.f(float)
8+
9+
; CHECK-LABEL: lg2_float
10+
define float @lg2_float(float %0) {
11+
; CHECK-LABEL: lg2_float(
12+
; CHECK: {
13+
; CHECK-NEXT: .reg .f32 %f<3>;
14+
; CHECK-EMPTY:
15+
; CHECK-NEXT: // %bb.0:
16+
; CHECK-NEXT: ld.param.f32 %f1, [lg2_float_param_0];
17+
; CHECK-NEXT: lg2.approx.f32 %f2, %f1;
18+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
19+
; CHECK-NEXT: ret;
20+
%res = call float @llvm.nvvm.lg2.approx.f(float %0)
21+
ret float %res
22+
}
23+
24+
; CHECK-LABEL: lg2_float_ftz
25+
define float @lg2_float_ftz(float %0) {
26+
; CHECK-LABEL: lg2_float_ftz(
27+
; CHECK: {
28+
; CHECK-NEXT: .reg .f32 %f<3>;
29+
; CHECK-EMPTY:
30+
; CHECK-NEXT: // %bb.0:
31+
; CHECK-NEXT: ld.param.f32 %f1, [lg2_float_ftz_param_0];
32+
; CHECK-NEXT: lg2.approx.ftz.f32 %f2, %f1;
33+
; CHECK-NEXT: st.param.f32 [func_retval0], %f2;
34+
; CHECK-NEXT: ret;
35+
%res = call float @llvm.nvvm.lg2.approx.ftz.f(float %0)
36+
ret float %res
37+
}

0 commit comments

Comments
 (0)