Skip to content

Commit 52238e1

Browse files
authored
[SYCL][NVPTX] Set default fdiv and sqrt for llvm.fpbuiltin (#16714)
AltMathLibrary is lacking an implementation of the llvm.fpbuiltin intrinsics for the NVPTX target. This patch adds a type-dependent mapping of llvm.fpbuiltin.fdiv with max-error >= 2.0 and llvm.fpbuiltin.sqrt with max-error >= 1.0 onto nvvm intrinsics: fp32 scalar @llvm.fpbuiltin.fdiv -> @llvm.nvvm.div.approx.f; fp32 scalar @llvm.fpbuiltin.sqrt -> @llvm.nvvm.sqrt.approx.f; vector or non-fp32 scalar llvm.fpbuiltin.fdiv -> fdiv; vector or non-fp32 scalar llvm.fpbuiltin.sqrt -> llvm.sqrt. Additionally, it maps the max-error=0.5 fpbuiltin.fadd, fpbuiltin.fsub, fpbuiltin.fmul, fpbuiltin.fdiv, fpbuiltin.frem, fpbuiltin.sqrt and fpbuiltin.ldexp intrinsic functions onto LLVM's math operations or standard C/C++ library intrinsics (https://llvm.org/docs/LangRef.html#standard-c-c-library-intrinsics). TODO in future patches: - add preservation of debug info in FPBuiltinFnSelection; - move tests from CodeGen to Transform; - move the pass to the new pass manager. Signed-off-by: Sidorov, Dmitry <[email protected]> --------- Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent ddea941 commit 52238e1

File tree

3 files changed

+372
-4
lines changed

3 files changed

+372
-4
lines changed

llvm/lib/Transforms/Scalar/FPBuiltinFnSelection.cpp

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/IR/IRBuilder.h"
1919
#include "llvm/IR/InstIterator.h"
2020
#include "llvm/IR/IntrinsicInst.h"
21+
#include "llvm/IR/IntrinsicsNVPTX.h"
2122
#include "llvm/InitializePasses.h"
2223
#include "llvm/Support/FormatVariadic.h"
2324

@@ -106,6 +107,51 @@ static bool replaceWithLLVMIR(FPBuiltinIntrinsic &BuiltinCall) {
106107
return true;
107108
}
108109

110+
// This function lowers llvm.fpbuiltin. intrinsic functions with max-error
111+
// attribute to the appropriate nvvm approximate intrinsics if it's possible.
112+
// If it's not possible - fallback to instruction or standard C/C++ library LLVM
113+
// intrinsic.
114+
static bool
115+
replaceWithApproxNVPTXCallsOrFallback(FPBuiltinIntrinsic &BuiltinCall,
116+
std::optional<float> Accuracy) {
117+
IRBuilder<> IRBuilder(&BuiltinCall);
118+
SmallVector<Value *> Args(BuiltinCall.args());
119+
Value *Replacement = nullptr;
120+
auto *Type = BuiltinCall.getType();
121+
// For now only add lowering for fdiv and sqrt. Yet nvvm intrinsics have
122+
// approximate variants for sin, cos, exp2 and log2.
123+
// For vector fpbuiltins for NVPTX target we don't have nvvm intrinsics,
124+
// fallback to instruction or standard C/C++ library LLVM intrinsic. Also
125+
// nvvm fdiv and sqrt intrisics support only float type, so fallback in this
126+
// case as well.
127+
switch (BuiltinCall.getIntrinsicID()) {
128+
case Intrinsic::fpbuiltin_fdiv:
129+
if (Accuracy.value() < 2.0)
130+
return false;
131+
if (Type->isVectorTy() || !Type->getScalarType()->isFloatTy())
132+
return replaceWithLLVMIR(BuiltinCall);
133+
Replacement =
134+
IRBuilder.CreateIntrinsic(Type, Intrinsic::nvvm_div_approx_f, Args);
135+
break;
136+
case Intrinsic::fpbuiltin_sqrt:
137+
if (Accuracy.value() < 1.0)
138+
return false;
139+
if (Type->isVectorTy() || !Type->getScalarType()->isFloatTy())
140+
return replaceWithLLVMIR(BuiltinCall);
141+
Replacement =
142+
IRBuilder.CreateIntrinsic(Type, Intrinsic::nvvm_sqrt_approx_f, Args);
143+
break;
144+
default:
145+
return false;
146+
}
147+
BuiltinCall.replaceAllUsesWith(Replacement);
148+
cast<Instruction>(Replacement)->copyFastMathFlags(&BuiltinCall);
149+
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
150+
<< BuiltinCall.getCalledFunction()->getName()
151+
<< "` with equivalent IR. \n `");
152+
return true;
153+
}
154+
109155
static bool selectFnForFPBuiltinCalls(const TargetLibraryInfo &TLI,
110156
const TargetTransformInfo &TTI,
111157
FPBuiltinIntrinsic &BuiltinCall) {
@@ -136,10 +182,11 @@ static bool selectFnForFPBuiltinCalls(const TargetLibraryInfo &TLI,
136182
return replaceWithLLVMIR(BuiltinCall);
137183

138184
// Several functions for "sycl" and "cuda" requires "0.5" accuracy levels,
139-
// which means correctly rounded results. For now x86 host AltMathLibrary
140-
// doesn't have such ability. For such accuracy level, the fpbuiltins
141-
// should be replaced by equivalent IR operation or llvmbuiltins.
142-
if (T.isX86() && BuiltinCall.getRequiredAccuracy().value() == 0.5) {
185+
// which means correctly rounded results. For now x86 host and NVPTX
186+
// AltMathLibrary doesn't have such ability. For such accuracy level, the
187+
// fpbuiltins should be replaced by equivalent IR operation or llvmbuiltins.
188+
if ((T.isX86() || T.isNVPTX()) &&
189+
BuiltinCall.getRequiredAccuracy().value() == 0.5) {
143190
switch (BuiltinCall.getIntrinsicID()) {
144191
case Intrinsic::fpbuiltin_fadd:
145192
case Intrinsic::fpbuiltin_fsub:
@@ -154,6 +201,14 @@ static bool selectFnForFPBuiltinCalls(const TargetLibraryInfo &TLI,
154201
}
155202
}
156203

204+
// AltMathLibrary doesn't have an implementation for CUDA approximate
205+
// precision builtins. Let's map them onto NVPTX intrinsics. If no appropriate
206+
// intrinsic is known, skip the replacement so that an error is emitted.
207+
if (T.isNVPTX() && BuiltinCall.getRequiredAccuracy().value() > 0.5)
208+
if (replaceWithApproxNVPTXCallsOrFallback(
209+
BuiltinCall, BuiltinCall.getRequiredAccuracy()))
210+
return true;
211+
157212
/// Call TLI to select a function implementation to call
158213
StringRef ImplName = TLI.selectFPBuiltinImplementation(&BuiltinCall);
159214
if (ImplName.empty()) {
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
; RUN: opt -fpbuiltin-fn-selection -S < %s | FileCheck %s
2+
3+
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
4+
target triple = "nvptx64-nvidia-cuda"
5+
6+
; CHECK-LABEL: @test_fdiv
7+
; CHECK: %{{.*}} = call float @llvm.nvvm.div.approx.f(float %{{.*}}, float %{{.*}})
8+
; CHECK: %{{.*}} = fdiv <2 x float> %{{.*}}, %{{.*}}
9+
define void @test_fdiv(float %d1, <2 x float> %v2d1,
10+
float %d2, <2 x float> %v2d2) {
11+
entry:
12+
%t0 = call float @llvm.fpbuiltin.fdiv.f32(float %d1, float %d2) #0
13+
%t1 = call <2 x float> @llvm.fpbuiltin.fdiv.v2f32(<2 x float> %v2d1, <2 x float> %v2d2) #0
14+
ret void
15+
}
16+
17+
; CHECK-LABEL: @test_fdiv_fast
18+
; CHECK: %{{.*}} = call fast float @llvm.nvvm.div.approx.f(float %{{.*}}, float %{{.*}})
19+
; CHECK: %{{.*}} = fdiv fast <2 x float> %{{.*}}, %{{.*}}
20+
define void @test_fdiv_fast(float %d1, <2 x float> %v2d1,
21+
float %d2, <2 x float> %v2d2) {
22+
entry:
23+
%t0 = call fast float @llvm.fpbuiltin.fdiv.f32(float %d1, float %d2) #0
24+
%t1 = call fast <2 x float> @llvm.fpbuiltin.fdiv.v2f32(<2 x float> %v2d1, <2 x float> %v2d2) #0
25+
ret void
26+
}
27+
28+
; CHECK-LABEL: @test_fdiv_max_error
29+
; CHECK: %{{.*}} = call float @llvm.nvvm.div.approx.f(float %{{.*}}, float %{{.*}})
30+
; CHECK: %{{.*}} = fdiv <2 x float> %{{.*}}, %{{.*}}
31+
define void @test_fdiv_max_error(float %d1, <2 x float> %v2d1,
32+
float %d2, <2 x float> %v2d2) {
33+
entry:
34+
%t0 = call float @llvm.fpbuiltin.fdiv.f32(float %d1, float %d2) #2
35+
%t1 = call <2 x float> @llvm.fpbuiltin.fdiv.v2f32(<2 x float> %v2d1, <2 x float> %v2d2) #2
36+
ret void
37+
}
38+
39+
declare float @llvm.fpbuiltin.fdiv.f32(float, float)
40+
declare <2 x float> @llvm.fpbuiltin.fdiv.v2f32(<2 x float>, <2 x float>)
41+
42+
; CHECK-LABEL: @test_fdiv_double
43+
; CHECK: %{{.*}} = fdiv double %{{.*}}, %{{.*}}
44+
; CHECK: %{{.*}} = fdiv <2 x double> %{{.*}}, %{{.*}}
45+
define void @test_fdiv_double(double %d1, <2 x double> %v2d1,
46+
double %d2, <2 x double> %v2d2) {
47+
entry:
48+
%t0 = call double @llvm.fpbuiltin.fdiv.f64(double %d1, double %d2) #0
49+
%t1 = call <2 x double> @llvm.fpbuiltin.fdiv.v2f64(<2 x double> %v2d1, <2 x double> %v2d2) #0
50+
ret void
51+
}
52+
53+
declare double @llvm.fpbuiltin.fdiv.f64(double, double)
54+
declare <2 x double> @llvm.fpbuiltin.fdiv.v2f64(<2 x double>, <2 x double>)
55+
56+
; CHECK-LABEL: @test_sqrt
57+
; CHECK: %{{.*}} = call float @llvm.nvvm.sqrt.approx.f(float %{{.*}})
58+
; CHECK: %{{.*}} = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %{{.*}})
59+
define void @test_sqrt(float %d, <2 x float> %v2d, <4 x float> %v4d) {
60+
entry:
61+
%t0 = call float @llvm.fpbuiltin.sqrt.f32(float %d) #1
62+
%t1 = call <2 x float> @llvm.fpbuiltin.sqrt.v2f32(<2 x float> %v2d) #1
63+
ret void
64+
}
65+
66+
; CHECK-LABEL: @test_sqrt_max_error
67+
; CHECK: %{{.*}} = call float @llvm.nvvm.sqrt.approx.f(float %{{.*}})
68+
; CHECK: %{{.*}} = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %{{.*}})
69+
define void @test_sqrt_max_error(float %d, <2 x float> %v2d, <4 x float> %v4d) {
70+
entry:
71+
%t0 = call float @llvm.fpbuiltin.sqrt.f32(float %d) #2
72+
%t1 = call <2 x float> @llvm.fpbuiltin.sqrt.v2f32(<2 x float> %v2d) #2
73+
ret void
74+
}
75+
76+
declare float @llvm.fpbuiltin.sqrt.f32(float)
77+
declare <2 x float> @llvm.fpbuiltin.sqrt.v2f32(<2 x float>)
78+
79+
; CHECK-LABEL: @test_sqrt_double
80+
; CHECK: %{{.*}} = call double @llvm.sqrt.f64(double %{{.*}})
81+
; CHECK: %{{.*}} = call <2 x double> @llvm.sqrt.v2f64(<2 x double> %{{.*}})
82+
define void @test_sqrt_double(double %d, <2 x double> %v2d) {
83+
entry:
84+
%t0 = call double @llvm.fpbuiltin.sqrt.f64(double %d) #1
85+
%t1 = call <2 x double> @llvm.fpbuiltin.sqrt.v2f64(<2 x double> %v2d) #1
86+
ret void
87+
}
88+
89+
declare double @llvm.fpbuiltin.sqrt.f64(double)
90+
declare <2 x double> @llvm.fpbuiltin.sqrt.v2f64(<2 x double>)
91+
92+
attributes #0 = { "fpbuiltin-max-error"="2.5" }
93+
attributes #1 = { "fpbuiltin-max-error"="3.0" }
94+
attributes #2 = { "fpbuiltin-max-error"="10.0" }

0 commit comments

Comments
 (0)