Skip to content

Commit df60805

Browse files
authored
[NVPTX] Improve support for rsqrt.approx (#89417)
Complete support for rsqrt.approx with rsqrt.approx.f64 ([PTX ISA 9.7.3.17. Floating Point Instructions: rsqrt.approx.ftz.f64](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64)). Additionally, add support for folding `sqrt` into `rsqrt`, with an optional flag to disable.
1 parent f426be1 commit df60805

File tree

7 files changed

+145
-0
lines changed

7 files changed

+145
-0
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,8 @@ let TargetPrefix = "nvvm" in {
10031003

10041004
def int_nvvm_rsqrt_approx_ftz_f : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_f">,
10051005
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
1006+
def int_nvvm_rsqrt_approx_ftz_d : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_d">,
1007+
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem]>;
10061008
def int_nvvm_rsqrt_approx_f : ClangBuiltin<"__nvvm_rsqrt_approx_f">,
10071009
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
10081010
def int_nvvm_rsqrt_approx_d : ClangBuiltin<"__nvvm_rsqrt_approx_d">,

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ using namespace llvm;
3030
#define DEBUG_TYPE "nvptx-isel"
3131
#define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"
3232

33+
static cl::opt<bool>
34+
EnableRsqrtOpt("nvptx-rsqrt-approx-opt", cl::init(true), cl::Hidden,
35+
cl::desc("Enable reciprocal sqrt optimization"));
36+
3337
/// createNVPTXISelDag - This pass converts a legalized DAG into a
3438
/// NVPTX-specific DAG, ready for instruction scheduling.
3539
FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
@@ -74,6 +78,8 @@ bool NVPTXDAGToDAGISel::allowUnsafeFPMath() const {
7478
return TL->allowUnsafeFPMath(*MF);
7579
}
7680

81+
bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return EnableRsqrtOpt; }
82+
7783
/// Select - Select instructions not customized! Used for
7884
/// expanded, promoted and normal instructions.
7985
void NVPTXDAGToDAGISel::Select(SDNode *N) {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
3636
bool useF32FTZ() const;
3737
bool allowFMA() const;
3838
bool allowUnsafeFPMath() const;
39+
bool doRsqrtOpt() const;
3940

4041
public:
4142
static char ID;

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def hasLDU : Predicate<"Subtarget->hasLDU()">;
142142

143143
def doF32FTZ : Predicate<"useF32FTZ()">;
144144
def doNoF32FTZ : Predicate<"!useF32FTZ()">;
145+
def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
145146

146147
def doMulWide : Predicate<"doMulWide">;
147148

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,11 +1171,36 @@ def : Pat<(int_nvvm_sqrt_f Float32Regs:$a),
11711171
def INT_NVVM_RSQRT_APPROX_FTZ_F
11721172
: F_MATH_1<"rsqrt.approx.ftz.f32 \t$dst, $src0;", Float32Regs, Float32Regs,
11731173
int_nvvm_rsqrt_approx_ftz_f>;
1174+
def INT_NVVM_RSQRT_APPROX_FTZ_D
1175+
: F_MATH_1<"rsqrt.approx.ftz.f64 \t$dst, $src0;", Float64Regs, Float64Regs,
1176+
int_nvvm_rsqrt_approx_ftz_d>;
1177+
11741178
def INT_NVVM_RSQRT_APPROX_F : F_MATH_1<"rsqrt.approx.f32 \t$dst, $src0;",
11751179
Float32Regs, Float32Regs, int_nvvm_rsqrt_approx_f>;
11761180
def INT_NVVM_RSQRT_APPROX_D : F_MATH_1<"rsqrt.approx.f64 \t$dst, $src0;",
11771181
Float64Regs, Float64Regs, int_nvvm_rsqrt_approx_d>;
11781182

1183+
// 1.0f / sqrt_approx -> rsqrt_approx
1184+
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_f Float32Regs:$a)),
1185+
(INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
1186+
Requires<[doRsqrtOpt]>;
1187+
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_ftz_f Float32Regs:$a)),
1188+
(INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
1189+
Requires<[doRsqrtOpt]>;
1190+
// same for int_nvvm_sqrt_f when non-precision sqrt is requested
1191+
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
1192+
(INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
1193+
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
1194+
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
1195+
(INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
1196+
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
1197+
1198+
def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
1199+
(INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
1200+
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
1201+
def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
1202+
(INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
1203+
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
11791204
//
11801205
// Add
11811206
//

llvm/test/CodeGen/NVPTX/rsqrt-opt.ll

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
; RUN: llc < %s -march=nvptx64 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-NOOPT
2+
; RUN: llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-OPT
3+
; RUN: llc < %s -march=nvptx64 -nvptx-rsqrt-approx-opt=0 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-NOOPT,CHECK-SQRT-NOOPT
4+
;
5+
; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
6+
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | %ptxas-verify %}
7+
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-rsqrt-approx-opt=0 | %ptxas-verify %}
8+
9+
10+
; CHECK-LABEL: .func{{.*}}test1
11+
define float @test1(float %in) local_unnamed_addr {
12+
; CHECK-APPROX-OPT: rsqrt.approx.f32
13+
; CHECK-APPROX-NOOPT: sqrt.approx.f32
14+
; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32
15+
%sqrt = tail call float @llvm.nvvm.sqrt.approx.f(float %in)
16+
%rsqrt = fdiv float 1.0, %sqrt
17+
ret float %rsqrt
18+
}
19+
; CHECK-LABEL: .func{{.*}}test2
20+
define float @test2(float %in) local_unnamed_addr {
21+
; CHECK-APPROX-OPT: rsqrt.approx.ftz.f32
22+
; CHECK-APPROX-NOOPT: sqrt.approx.ftz.f32
23+
; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32
24+
%sqrt = tail call float @llvm.nvvm.sqrt.approx.ftz.f(float %in)
25+
%rsqrt = fdiv float 1.0, %sqrt
26+
ret float %rsqrt
27+
}
28+
29+
; CHECK-LABEL: .func{{.*}}test3
30+
define float @test3(float %in) local_unnamed_addr {
31+
; CHECK-SQRT-OPT: rsqrt.approx.f32
32+
; CHECK-SQRT-NOOPT: sqrt.rn.f32
33+
; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32
34+
%sqrt = tail call float @llvm.nvvm.sqrt.f(float %in)
35+
%rsqrt = fdiv float 1.0, %sqrt
36+
ret float %rsqrt
37+
}
38+
39+
; CHECK-LABEL: .func{{.*}}test4
40+
define float @test4(float %in) local_unnamed_addr #0 {
41+
; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32
42+
; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32
43+
; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32
44+
%sqrt = tail call float @llvm.nvvm.sqrt.f(float %in)
45+
%rsqrt = fdiv float 1.0, %sqrt
46+
ret float %rsqrt
47+
}
48+
49+
; CHECK-LABEL: .func{{.*}}test5
50+
define float @test5(float %in) local_unnamed_addr {
51+
; CHECK-SQRT-OPT: rsqrt.approx.f32
52+
; CHECK-SQRT-NOOPT: sqrt.rn.f32
53+
; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32
54+
%sqrt = tail call float @llvm.sqrt.f32(float %in)
55+
%rsqrt = fdiv float 1.0, %sqrt
56+
ret float %rsqrt
57+
}
58+
59+
; CHECK-LABEL: .func{{.*}}test6
60+
define float @test6(float %in) local_unnamed_addr #0 {
61+
; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32
62+
; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32
63+
; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32
64+
%sqrt = tail call float @llvm.sqrt.f32(float %in)
65+
%rsqrt = fdiv float 1.0, %sqrt
66+
ret float %rsqrt
67+
}
68+
69+
70+
declare float @llvm.nvvm.sqrt.f(float)
71+
declare float @llvm.nvvm.sqrt.approx.f(float)
72+
declare float @llvm.nvvm.sqrt.approx.ftz.f(float)
73+
declare float @llvm.sqrt.f32(float)
74+
75+
attributes #0 = { "denormal-fp-math-f32" = "preserve-sign" }

llvm/test/CodeGen/NVPTX/rsqrt.ll

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
; RUN: llc < %s -march=nvptx64 | FileCheck %s
2+
; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
3+
4+
; CHECK-LABEL: .func{{.*}}test1
5+
define float @test1(float %in) local_unnamed_addr {
6+
; CHECK: rsqrt.approx.f32
7+
%call = call float @llvm.nvvm.rsqrt.approx.f(float %in)
8+
ret float %call
9+
}
10+
11+
; CHECK-LABEL: .func{{.*}}test2
12+
define double @test2(double %in) local_unnamed_addr {
13+
; CHECK: rsqrt.approx.f64
14+
%call = call double @llvm.nvvm.rsqrt.approx.d(double %in)
15+
ret double %call
16+
}
17+
18+
; CHECK-LABEL: .func{{.*}}test3
19+
define float @test3(float %in) local_unnamed_addr {
20+
; CHECK: rsqrt.approx.ftz.f32
21+
%call = tail call float @llvm.nvvm.rsqrt.approx.ftz.f(float %in)
22+
ret float %call
23+
}
24+
25+
; CHECK-LABEL: .func{{.*}}test4
26+
define double @test4(double %in) local_unnamed_addr {
27+
; CHECK: rsqrt.approx.ftz.f64
28+
%call = tail call double @llvm.nvvm.rsqrt.approx.ftz.d(double %in)
29+
ret double %call
30+
}
31+
32+
declare float @llvm.nvvm.rsqrt.approx.ftz.f(float)
33+
declare double @llvm.nvvm.rsqrt.approx.ftz.d(double)
34+
declare float @llvm.nvvm.rsqrt.approx.f(float)
35+
declare double @llvm.nvvm.rsqrt.approx.d(double)

0 commit comments

Comments
 (0)