Skip to content

Commit 3ad9ba5

Browse files
yxsamliuDavid Salinas
authored andcommitted
[HIP] Fix __clang_hip_cmath.hip for ambiguity (llvm#101341)
If there is a type T which can be converted to both float and double etc but itself is not specialized for __numeric_type, and it is called for math functions eg. fma, it will cause ambiguity with test function of __numeric_type. Since test is not template, this error is not bypassed by SFINAE. This is a design flaw of __numeric_type. This patch fixes clang wrapper header to use SFINAE to avoid such ambiguity. Fixes: SWDEV-461604 Fixes: llvm#101239 Change-Id: I285adbeb7bf8fe57d084625b98fa6fd49092d641
1 parent 2d9cc50 commit 3ad9ba5

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

clang/lib/Headers/__clang_hip_cmath.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,12 @@ template <class _Tp> struct __numeric_type {
397397
// No support for long double, use double instead.
398398
static double __test(long double);
399399

400-
typedef decltype(__test(declval<_Tp>())) type;
400+
template <typename _U>
401+
static auto __test_impl(int) -> decltype(__test(declval<_U>()));
402+
403+
template <typename _U> static void __test_impl(...);
404+
405+
typedef decltype(__test_impl<_Tp>(0)) type;
401406
static const bool value = !is_same<type, void>::value;
402407
};
403408

clang/test/Headers/__clang_hip_cmath.hip

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,22 @@ extern "C" __device__ float test_sin_f32(float x) {
8686
extern "C" __device__ float test_cos_f32(float x) {
8787
return cos(x);
8888
}
89+
90+
// Check user defined type which can be converted to float and double but not
91+
// specializes __numeric_type will not cause ambiguity diagnostics.
92+
struct user_bfloat16 {
93+
__host__ __device__ user_bfloat16(float);
94+
operator float();
95+
operator double();
96+
};
97+
98+
namespace user_namespace {
99+
__device__ user_bfloat16 fma(const user_bfloat16 a, const user_bfloat16 b, const user_bfloat16 c) {
100+
return a;
101+
}
102+
103+
__global__ void test_fma() {
104+
user_bfloat16 a = 1.0f, b = 2.0f;
105+
fma(a, b, b);
106+
}
107+
}

0 commit comments

Comments
 (0)