Skip to content

Commit 48bb501

Browse files
authored
[SYCL] Emit fpbuiltin version of function only for function with FP arguments. (#17253)
When the arguments of a function are of half types, such as `sqrt(half_type v)` they are expanded to: `load ptr addrspace(4), ptr addrspace(4) %a.addr.ascast, align 8, !tbaa !5` This will generate a crash in the BE due to type mismatch. The expected type of the function is `float`. This patch fixes the issue by restricting the emission of `fpbuiltin` only to functions that take floating point values arguments.
1 parent 416eb46 commit 48bb501

File tree

3 files changed

+252
-79
lines changed

3 files changed

+252
-79
lines changed

clang/lib/CodeGen/CGCall.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5905,7 +5905,22 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
59055905
bool isFp32SqrtFunction =
59065906
(FuncName == "sqrt" && !getLangOpts().OffloadFP32PrecSqrt &&
59075907
IsFloat32Type);
5908-
if (hasFPAccuracyFuncMap || hasFPAccuracyVal || isFp32SqrtFunction) {
5908+
bool ArgsTypeIsFloat = true;
5909+
// In sycl mode, functions' arguments of type half are expanded
5910+
// to pointer types. Exclude these functions from being emitted
5911+
// as fpbuiltins.
5912+
if (!getLangOpts().OffloadFP32PrecSqrt ||
5913+
!getLangOpts().OffloadFP32PrecDiv) {
5914+
for (auto &Arg : IRCallArgs) {
5915+
if (!Arg->getType()->isFPOrFPVectorTy() &&
5916+
!Arg->getType()->isIntOrIntVectorTy()) {
5917+
ArgsTypeIsFloat = false;
5918+
break;
5919+
}
5920+
}
5921+
}
5922+
if (ArgsTypeIsFloat &&
5923+
(hasFPAccuracyFuncMap || hasFPAccuracyVal || isFp32SqrtFunction)) {
59095924
CI = MaybeEmitFPBuiltinofFD(IRFuncTy, IRCallArgs, CalleePtr,
59105925
FD->getName(), FD->getBuiltinID());
59115926
if (CI)
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
#define __SYCL_CONSTEXPR_HALF constexpr
2+
using StorageT = _Float16;
3+
4+
class half {
5+
public:
6+
half() = default;
7+
constexpr half(const half &) = default;
8+
constexpr half(half &&) = default;
9+
10+
__SYCL_CONSTEXPR_HALF half(const float &rhs) : Data(rhs) {}
11+
12+
constexpr half &operator=(const half &rhs) = default;
13+
14+
__SYCL_CONSTEXPR_HALF half &operator/=(const half &rhs) {
15+
Data /= rhs.Data;
16+
return *this;
17+
}
18+
19+
#define OP(op, op_eq) \
20+
__SYCL_CONSTEXPR_HALF friend half operator op(const half lhs, \
21+
const half rhs) { \
22+
half rtn = lhs; \
23+
rtn op_eq rhs; \
24+
return rtn; \
25+
} \
26+
__SYCL_CONSTEXPR_HALF friend double operator op(const half lhs, \
27+
const double rhs) { \
28+
double rtn = lhs; \
29+
rtn op_eq rhs; \
30+
return rtn; \
31+
} \
32+
__SYCL_CONSTEXPR_HALF friend double operator op(const double lhs, \
33+
const half rhs) { \
34+
double rtn = lhs; \
35+
rtn op_eq rhs; \
36+
return rtn; \
37+
} \
38+
__SYCL_CONSTEXPR_HALF friend float operator op(const half lhs, \
39+
const float rhs) { \
40+
float rtn = lhs; \
41+
rtn op_eq rhs; \
42+
return rtn; \
43+
} \
44+
__SYCL_CONSTEXPR_HALF friend float operator op(const float lhs, \
45+
const half rhs) { \
46+
float rtn = lhs; \
47+
rtn op_eq rhs; \
48+
return rtn; \
49+
} \
50+
__SYCL_CONSTEXPR_HALF friend half operator op(const half lhs, \
51+
const int rhs) { \
52+
half rtn = lhs; \
53+
rtn op_eq rhs; \
54+
return rtn; \
55+
} \
56+
__SYCL_CONSTEXPR_HALF friend half operator op(const int lhs, \
57+
const half rhs) { \
58+
half rtn = lhs; \
59+
rtn op_eq rhs; \
60+
return rtn; \
61+
} \
62+
__SYCL_CONSTEXPR_HALF friend half operator op(const half lhs, \
63+
const long rhs) { \
64+
half rtn = lhs; \
65+
rtn op_eq rhs; \
66+
return rtn; \
67+
} \
68+
__SYCL_CONSTEXPR_HALF friend half operator op(const long lhs, \
69+
const half rhs) { \
70+
half rtn = lhs; \
71+
rtn op_eq rhs; \
72+
return rtn; \
73+
} \
74+
__SYCL_CONSTEXPR_HALF friend half operator op(const half lhs, \
75+
const long long rhs) { \
76+
half rtn = lhs; \
77+
rtn op_eq rhs; \
78+
return rtn; \
79+
} \
80+
__SYCL_CONSTEXPR_HALF friend half operator op(const long long lhs, \
81+
const half rhs) { \
82+
half rtn = lhs; \
83+
rtn op_eq rhs; \
84+
return rtn; \
85+
} \
86+
__SYCL_CONSTEXPR_HALF friend half operator op(const half &lhs, \
87+
const unsigned int &rhs) { \
88+
half rtn = lhs; \
89+
rtn op_eq rhs; \
90+
return rtn; \
91+
} \
92+
__SYCL_CONSTEXPR_HALF friend half operator op(const unsigned int &lhs, \
93+
const half &rhs) { \
94+
half rtn = lhs; \
95+
rtn op_eq rhs; \
96+
return rtn; \
97+
} \
98+
__SYCL_CONSTEXPR_HALF friend half operator op(const half &lhs, \
99+
const unsigned long &rhs) { \
100+
half rtn = lhs; \
101+
rtn op_eq rhs; \
102+
return rtn; \
103+
} \
104+
__SYCL_CONSTEXPR_HALF friend half operator op(const unsigned long &lhs, \
105+
const half &rhs) { \
106+
half rtn = lhs; \
107+
rtn op_eq rhs; \
108+
return rtn; \
109+
} \
110+
__SYCL_CONSTEXPR_HALF friend half operator op( \
111+
const half &lhs, const unsigned long long &rhs) { \
112+
half rtn = lhs; \
113+
rtn op_eq rhs; \
114+
return rtn; \
115+
} \
116+
__SYCL_CONSTEXPR_HALF friend half operator op(const unsigned long long &lhs, \
117+
const half &rhs) { \
118+
half rtn = lhs; \
119+
rtn op_eq rhs; \
120+
return rtn; \
121+
}
122+
OP(/, /=)
123+
124+
#undef OP
125+
126+
// Operator float
127+
__SYCL_CONSTEXPR_HALF operator float() const {
128+
return static_cast<float>(Data);
129+
}
130+
131+
private:
132+
__SYCL_CONSTEXPR_HALF StorageT getFPRep() const { return Data; }
133+
134+
StorageT Data;
135+
};
136+

0 commit comments

Comments
 (0)