Skip to content

Commit 42eebc1

Browse files
jinge90bb-sycl
authored andcommitted
Add tests for bfloat16 isnan (intel#1509)
Signed-off-by: jinge90 <[email protected]>
1 parent c0ad377 commit 42eebc1

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

SYCL/BFloat16/bfloat16_builtins.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ bool check(float a, float b) {
2828
return fabs(2 * (a - b) / (a + b)) > bf16_eps * 2;
2929
}
3030

31+
bool check(bool a, bool b) { return (a == b); }
32+
3133
#define TEST_BUILTIN_1_SCAL_IMPL(NAME) \
3234
{ \
3335
buffer<float> a_buf(&a[0], N); \
@@ -45,7 +47,7 @@ bool check(float a, float b) {
4547
} \
4648
assert(err == 0);
4749

48-
#define TEST_BUILTIN_1_ARR_IMPL(NAME, SZ) \
50+
#define TEST_BUILTIN_1_ARR_IMPL(NAME, SZ, RETTY) \
4951
{ \
5052
buffer<float, 2> a_buf{range<2>{N / SZ, SZ}}; \
5153
buffer<int> err_buf(&err, 1); \
@@ -58,7 +60,7 @@ bool check(float a, float b) {
5860
for (int i = 0; i < SZ; i++) { \
5961
arg[i] = A[index][i]; \
6062
} \
61-
marray<bfloat16, SZ> res = NAME(arg); \
63+
marray<RETTY, SZ> res = NAME(arg); \
6264
for (int i = 0; i < SZ; i++) { \
6365
if (check(res[i], NAME(A[index][i]))) { \
6466
ERR[0] = 1; \
@@ -69,13 +71,13 @@ bool check(float a, float b) {
6971
} \
7072
assert(err == 0);
7173

72-
#define TEST_BUILTIN_1(NAME) \
74+
#define TEST_BUILTIN_1(NAME, RETTY) \
7375
TEST_BUILTIN_1_SCAL_IMPL(NAME) \
74-
TEST_BUILTIN_1_ARR_IMPL(NAME, 1) \
75-
TEST_BUILTIN_1_ARR_IMPL(NAME, 2) \
76-
TEST_BUILTIN_1_ARR_IMPL(NAME, 3) \
77-
TEST_BUILTIN_1_ARR_IMPL(NAME, 4) \
78-
TEST_BUILTIN_1_ARR_IMPL(NAME, 5)
76+
TEST_BUILTIN_1_ARR_IMPL(NAME, 1, RETTY) \
77+
TEST_BUILTIN_1_ARR_IMPL(NAME, 2, RETTY) \
78+
TEST_BUILTIN_1_ARR_IMPL(NAME, 3, RETTY) \
79+
TEST_BUILTIN_1_ARR_IMPL(NAME, 4, RETTY) \
80+
TEST_BUILTIN_1_ARR_IMPL(NAME, 5, RETTY)
7981

8082
#define TEST_BUILTIN_2_SCAL_IMPL(NAME) \
8183
{ \
@@ -232,14 +234,18 @@ int main() {
232234
c[i] = (float)(3 * i);
233235
}
234236

235-
TEST_BUILTIN_1(fabs);
237+
TEST_BUILTIN_1(fabs, bfloat16);
236238
TEST_BUILTIN_2(fmin);
237239
TEST_BUILTIN_2(fmax);
238240
TEST_BUILTIN_3(fma);
239241

240242
float check_nan = 0;
241243
TEST_BUILTIN_2_NAN(fmin);
242244
TEST_BUILTIN_2_NAN(fmax);
245+
246+
// Insert NAN value in a to test isnan
247+
a[0] = a[N - 1] = NAN;
248+
TEST_BUILTIN_1(isnan, bool);
243249
}
244250
return 0;
245251
}

0 commit comments

Comments
 (0)