Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit b1ef695

Browse files
authored
Add tests for bfloat16 isnan (#1509)
Signed-off-by: jinge90 <[email protected]>
1 parent b1d77f3 commit b1ef695

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

SYCL/BFloat16/bfloat16_builtins.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ bool check(float a, float b) {
2929
return fabs(2 * (a - b) / (a + b)) > bf16_eps * 2;
3030
}
3131

32+
bool check(bool a, bool b) { return (a == b); }
33+
3234
#define TEST_BUILTIN_1_SCAL_IMPL(NAME) \
3335
{ \
3436
buffer<float> a_buf(&a[0], N); \
@@ -46,7 +48,7 @@ bool check(float a, float b) {
4648
} \
4749
assert(err == 0);
4850

49-
#define TEST_BUILTIN_1_ARR_IMPL(NAME, SZ) \
51+
#define TEST_BUILTIN_1_ARR_IMPL(NAME, SZ, RETTY) \
5052
{ \
5153
buffer<float, 2> a_buf{range<2>{N / SZ, SZ}}; \
5254
buffer<int> err_buf(&err, 1); \
@@ -59,7 +61,7 @@ bool check(float a, float b) {
5961
for (int i = 0; i < SZ; i++) { \
6062
arg[i] = A[index][i]; \
6163
} \
62-
marray<bfloat16, SZ> res = NAME(arg); \
64+
marray<RETTY, SZ> res = NAME(arg); \
6365
for (int i = 0; i < SZ; i++) { \
6466
if (check(res[i], NAME(A[index][i]))) { \
6567
ERR[0] = 1; \
@@ -70,13 +72,13 @@ bool check(float a, float b) {
7072
} \
7173
assert(err == 0);
7274

73-
#define TEST_BUILTIN_1(NAME) \
75+
#define TEST_BUILTIN_1(NAME, RETTY) \
7476
TEST_BUILTIN_1_SCAL_IMPL(NAME) \
75-
TEST_BUILTIN_1_ARR_IMPL(NAME, 1) \
76-
TEST_BUILTIN_1_ARR_IMPL(NAME, 2) \
77-
TEST_BUILTIN_1_ARR_IMPL(NAME, 3) \
78-
TEST_BUILTIN_1_ARR_IMPL(NAME, 4) \
79-
TEST_BUILTIN_1_ARR_IMPL(NAME, 5)
77+
TEST_BUILTIN_1_ARR_IMPL(NAME, 1, RETTY) \
78+
TEST_BUILTIN_1_ARR_IMPL(NAME, 2, RETTY) \
79+
TEST_BUILTIN_1_ARR_IMPL(NAME, 3, RETTY) \
80+
TEST_BUILTIN_1_ARR_IMPL(NAME, 4, RETTY) \
81+
TEST_BUILTIN_1_ARR_IMPL(NAME, 5, RETTY)
8082

8183
#define TEST_BUILTIN_2_SCAL_IMPL(NAME) \
8284
{ \
@@ -233,14 +235,18 @@ int main() {
233235
c[i] = (float)(3 * i);
234236
}
235237

236-
TEST_BUILTIN_1(fabs);
238+
TEST_BUILTIN_1(fabs, bfloat16);
237239
TEST_BUILTIN_2(fmin);
238240
TEST_BUILTIN_2(fmax);
239241
TEST_BUILTIN_3(fma);
240242

241243
float check_nan = 0;
242244
TEST_BUILTIN_2_NAN(fmin);
243245
TEST_BUILTIN_2_NAN(fmax);
246+
247+
// Insert NAN value in a to test isnan
248+
a[0] = a[N - 1] = NAN;
249+
TEST_BUILTIN_1(isnan, bool);
244250
}
245251
return 0;
246252
}

0 commit comments

Comments
 (0)