Skip to content

Commit db1d3b2

Browse files
authored
[HIP] Fix __clang_hip_cmath.hip for ambiguity (#101341)
If a type T can be converted to both float and double (and so on) but has no __numeric_type specialization of its own, passing it to a math function such as fma makes the call to the __test function of __numeric_type ambiguous. Because __test is not a template, the ambiguity is a hard error rather than a substitution failure, so SFINAE does not bypass it. This is a design flaw of __numeric_type. This patch changes the clang wrapper header to use SFINAE to avoid such ambiguity. Fixes: SWDEV-461604 Fixes: #101239
1 parent 88ef76c commit db1d3b2

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

clang/lib/Headers/__clang_hip_cmath.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,12 @@ template <class _Tp> struct __numeric_type {
395395
// No support for long double, use double instead.
396396
static double __test(long double);
397397

398-
typedef decltype(__test(declval<_Tp>())) type;
398+
template <typename _U>
399+
static auto __test_impl(int) -> decltype(__test(declval<_U>()));
400+
401+
template <typename _U> static void __test_impl(...);
402+
403+
typedef decltype(__test_impl<_Tp>(0)) type;
399404
static const bool value = !is_same<type, void>::value;
400405
};
401406

clang/test/Headers/__clang_hip_cmath.hip

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,22 @@ extern "C" __device__ float test_sin_f32(float x) {
8787
extern "C" __device__ float test_cos_f32(float x) {
8888
return cos(x);
8989
}
90+
91+
// Check that a user-defined type which can be converted to float and double but
92+
// does not specialize __numeric_type will not cause ambiguity diagnostics.
93+
struct user_bfloat16 {
94+
__host__ __device__ user_bfloat16(float);
95+
operator float();
96+
operator double();
97+
};
98+
99+
namespace user_namespace {
100+
__device__ user_bfloat16 fma(const user_bfloat16 a, const user_bfloat16 b, const user_bfloat16 c) {
101+
return a;
102+
}
103+
104+
__global__ void test_fma() {
105+
user_bfloat16 a = 1.0f, b = 2.0f;
106+
fma(a, b, b);
107+
}
108+
}

0 commit comments

Comments
 (0)