Skip to content

Commit ebf9ea8

Browse files
authored
[SYCL][CUDA] Add -fcuda-prec-sqrt flag (#5141)
This patch adds `__nvvm_reflect` support for `__CUDA_PREC_SQRT` and adds a `-Xclang -fcuda-prec-sqrt` flag which is equivalent to the `nvcc` `-prec-sqrt` flag, except that it defaults to `false` for `clang++` and to `true` for `nvcc`. The reason for that is that the SYCL specification doesn't require a correctly rounded `sqrt` so we likely want to keep the fast `sqrt` as a default and use the flag when higher precision is required. See additional discussion on #4041 and #5116.
1 parent 87dae23 commit ebf9ea8

File tree

9 files changed

+56
-12
lines changed

9 files changed

+56
-12
lines changed

clang/include/clang/Basic/TargetOptions.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ class TargetOptions {
7575
/// address space.
7676
bool NVPTXUseShortPointers = false;
7777

78+
/// \brief If enabled, use precise square root
79+
bool NVVMCudaPrecSqrt = false;
80+
7881
/// \brief If enabled, allow AMDGPU unsafe floating point atomics.
7982
bool AllowAMDGPUUnsafeFPAtomics = false;
8083

clang/include/clang/Driver/Options.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,11 @@ defm cuda_short_ptr : BoolFOption<"cuda-short-ptr",
980980
TargetOpts<"NVPTXUseShortPointers">, DefaultFalse,
981981
PosFlag<SetTrue, [CC1Option], "Use 32-bit pointers for accessing const/local/shared address spaces">,
982982
NegFlag<SetFalse>>;
983+
defm cuda_prec_sqrt : BoolFOption<"cuda-prec-sqrt",
984+
TargetOpts<"NVVMCudaPrecSqrt">, DefaultFalse,
985+
PosFlag<SetTrue, [CC1Option], "Specify">,
986+
NegFlag<SetFalse, [], "Don't specify">,
987+
BothFlags<[], " that sqrt is correctly rounded (for CUDA devices)">>;
983988
def rocm_path_EQ : Joined<["--"], "rocm-path=">, Group<i_Group>,
984989
HelpText<"ROCm installation path, used for finding and automatically linking required bitcode libraries.">;
985990
def hip_path_EQ : Joined<["--"], "hip-path=">, Group<i_Group>,

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,13 +771,15 @@ void CodeGenModule::Release() {
771771
llvm::MDString::get(Ctx, CodeGenOpts.MemoryProfileOutput));
772772
}
773773

774-
if (LangOpts.CUDAIsDevice && getTriple().isNVPTX()) {
774+
if ((LangOpts.CUDAIsDevice || LangOpts.isSYCL()) && getTriple().isNVPTX()) {
775775
// Indicate whether __nvvm_reflect should be configured to flush denormal
776776
// floating point values to 0. (This corresponds to its "__CUDA_FTZ"
777777
// property.)
778778
getModule().addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz",
779779
CodeGenOpts.FP32DenormalMode.Output !=
780780
llvm::DenormalMode::IEEE);
781+
getModule().addModuleFlag(llvm::Module::Override, "nvvm-reflect-prec-sqrt",
782+
getTarget().getTargetOpts().NVVMCudaPrecSqrt);
781783
}
782784

783785
if (LangOpts.EHAsynch)

clang/test/CodeGenCUDA/flush-denormals.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ extern "C" __device__ void foo() {}
4444
// FTZ: attributes #0 = {{.*}} "denormal-fp-math-f32"="preserve-sign,preserve-sign"
4545
// NOFTZ-NOT: "denormal-fp-math-f32"
4646

47-
// PTXFTZ:!llvm.module.flags = !{{{.*}}[[MODFLAG:![0-9]+]]}
47+
// PTXFTZ:!llvm.module.flags = !{{{.*}}, [[MODFLAG:![0-9]+]], {{.*}}}
4848
// PTXFTZ:[[MODFLAG]] = !{i32 4, !"nvvm-reflect-ftz", i32 1}
4949

50-
// PTXNOFTZ:!llvm.module.flags = !{{{.*}}[[MODFLAG:![0-9]+]]}
50+
// PTXNOFTZ:!llvm.module.flags = !{{{.*}}, [[MODFLAG:![0-9]+]], {{.*}}}
5151
// PTXNOFTZ:[[MODFLAG]] = !{i32 4, !"nvvm-reflect-ftz", i32 0}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: %clang_cc1 -fcuda-is-device -triple nvptx64-nvidia-cuda -emit-llvm -fcuda-prec-sqrt %s -o -| FileCheck --check-prefix=CHECK-ON %s
2+
// RUN: %clang_cc1 -fcuda-is-device -triple nvptx64-nvidia-cuda -emit-llvm %s -o -| FileCheck --check-prefix=CHECK-OFF %s
3+
4+
#include "Inputs/cuda.h"
5+
6+
// Check that the -fcuda-prec-sqrt flag correctly sets the nvvm-reflect module flags.
7+
8+
extern "C" __device__ void foo() {}
9+
10+
// CHECK-ON: !{i32 4, !"nvvm-reflect-prec-sqrt", i32 1}
11+
// CHECK-OFF: !{i32 4, !"nvvm-reflect-prec-sqrt", i32 0}

llvm/docs/NVPTXUsage.rst

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -343,19 +343,22 @@ Reflection Parameters
343343
The libdevice library currently uses the following reflection parameters to
344344
control code generation:
345345

346-
==================== ======================================================
347-
Flag Description
348-
==================== ======================================================
349-
``__CUDA_FTZ=[0,1]`` Use optimized code paths that flush subnormals to zero
350-
==================== ======================================================
346+
=========================== ======================================================
347+
Flag Description
348+
=========================== ======================================================
349+
``__CUDA_FTZ=[0,1]`` Use optimized code paths that flush subnormals to zero
350+
``__CUDA_PREC_SQRT=[0,1]`` Use precise square root
351+
=========================== ======================================================
351352

352-
The value of this flag is determined by the "nvvm-reflect-ftz" module flag.
353-
The following sets the ftz flag to 1.
353+
The values of these flags are determined by the "nvvm-reflect-ftz" and
354+
"nvvm-reflect-prec-sqrt" module flags respectively.
355+
The following sets the ftz flag to 1, and the precise sqrt flag to 1.
354356

355357
.. code-block:: llvm
356358
357-
!llvm.module.flag = !{!0}
359+
!llvm.module.flags = !{!0, !1}
358360
!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}
361+
!1 = !{i32 4, !"nvvm-reflect-prec-sqrt", i32 1}
359362
360363
(``i32 4`` indicates that the value set here overrides the value in another
361364
module we link with. See the `LangRef <LangRef.html#module-flags-metadata>`

llvm/lib/Target/NVPTX/NVVMReflect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
170170
ReflectVal = Flag->getSExtValue();
171171
} else if (ReflectArg == "__CUDA_ARCH") {
172172
ReflectVal = SmVersion * 10;
173+
} else if (ReflectArg == "__CUDA_PREC_SQRT") {
174+
// Try to pull __CUDA_PREC_SQRT from the nvvm-reflect-prec-sqrt module
175+
// flag.
176+
if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
177+
F.getParent()->getModuleFlag("nvvm-reflect-prec-sqrt")))
178+
ReflectVal = Flag->getSExtValue();
173179
}
174180
Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
175181
ToRemove.push_back(Call);

llvm/test/CodeGen/NVPTX/nvvm-reflect-module-flag.ll

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,20 @@
33

44
declare i32 @__nvvm_reflect(i8*)
55
@str = private unnamed_addr addrspace(1) constant [11 x i8] c"__CUDA_FTZ\00"
6+
@str.1 = private unnamed_addr addrspace(1) constant [17 x i8] c"__CUDA_PREC_SQRT\00"
67

78
define i32 @foo() {
89
%call = call i32 @__nvvm_reflect(i8* addrspacecast (i8 addrspace(1)* getelementptr inbounds ([11 x i8], [11 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*))
910
; CHECK: ret i32 42
1011
ret i32 %call
1112
}
1213

13-
!llvm.module.flags = !{!0}
14+
define i32 @foo_sqrt() {
15+
%call = call i32 @__nvvm_reflect(i8* addrspacecast (i8 addrspace(1)* getelementptr inbounds ([17 x i8], [17 x i8] addrspace(1)* @str.1, i32 0, i32 0) to i8*))
16+
; CHECK: ret i32 42
17+
ret i32 %call
18+
}
19+
20+
!llvm.module.flags = !{!0, !1}
1421
!0 = !{i32 4, !"nvvm-reflect-ftz", i32 42}
22+
!1 = !{i32 4, !"nvvm-reflect-prec-sqrt", i32 42}

sycl/doc/GetStartedGuide.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,12 @@ which contains all the symbols required.
830830
GPU (SM 71), but it should work on any GPU compatible with SM 50 or above
831831
* The NVIDIA OpenCL headers conflict with the OpenCL headers required for this
832832
project and may cause compilation issues on some platforms
833+
* `sycl::sqrt` is not correctly rounded by default as the SYCL specification
834+
allows lower precision. When porting from CUDA it may be helpful to use
835+
`-Xclang -fcuda-prec-sqrt` to use the correctly rounded square root. This is
836+
significantly slower but matches the default precision used by `nvcc`. The
837+
`clang++` flag is equivalent to the `nvcc` `-prec-sqrt` flag, except that
838+
it defaults to `false`.
833839
834840
### HIP back-end limitations
835841

0 commit comments

Comments
 (0)