@@ -29,6 +29,8 @@ bool check(float a, float b) {
29
29
return fabs (2 * (a - b) / (a + b)) > bf16_eps * 2 ;
30
30
}
31
31
32
+ bool check (bool a, bool b) { return (a == b); }
33
+
32
34
#define TEST_BUILTIN_1_SCAL_IMPL (NAME ) \
33
35
{ \
34
36
buffer<float > a_buf (&a[0 ], N); \
@@ -46,7 +48,7 @@ bool check(float a, float b) {
46
48
} \
47
49
assert (err == 0 );
48
50
49
- #define TEST_BUILTIN_1_ARR_IMPL (NAME, SZ ) \
51
+ #define TEST_BUILTIN_1_ARR_IMPL (NAME, SZ, RETTY ) \
50
52
{ \
51
53
buffer<float , 2 > a_buf{range<2 >{N / SZ, SZ}}; \
52
54
buffer<int > err_buf (&err, 1 ); \
@@ -59,7 +61,7 @@ bool check(float a, float b) {
59
61
for (int i = 0 ; i < SZ; i++) { \
60
62
arg[i] = A[index][i]; \
61
63
} \
62
- marray<bfloat16 , SZ> res = NAME (arg); \
64
+ marray<RETTY , SZ> res = NAME (arg); \
63
65
for (int i = 0 ; i < SZ; i++) { \
64
66
if (check (res[i], NAME (A[index][i]))) { \
65
67
ERR[0 ] = 1 ; \
@@ -70,13 +72,13 @@ bool check(float a, float b) {
70
72
} \
71
73
assert (err == 0 );
72
74
73
- #define TEST_BUILTIN_1 (NAME ) \
75
+ #define TEST_BUILTIN_1 (NAME, RETTY ) \
74
76
TEST_BUILTIN_1_SCAL_IMPL (NAME) \
75
- TEST_BUILTIN_1_ARR_IMPL(NAME, 1 ) \
76
- TEST_BUILTIN_1_ARR_IMPL(NAME, 2 ) \
77
- TEST_BUILTIN_1_ARR_IMPL(NAME, 3 ) \
78
- TEST_BUILTIN_1_ARR_IMPL(NAME, 4 ) \
79
- TEST_BUILTIN_1_ARR_IMPL(NAME, 5 )
77
+ TEST_BUILTIN_1_ARR_IMPL(NAME, 1 , RETTY) \
78
+ TEST_BUILTIN_1_ARR_IMPL(NAME, 2 , RETTY) \
79
+ TEST_BUILTIN_1_ARR_IMPL(NAME, 3 , RETTY) \
80
+ TEST_BUILTIN_1_ARR_IMPL(NAME, 4 , RETTY) \
81
+ TEST_BUILTIN_1_ARR_IMPL(NAME, 5 , RETTY )
80
82
81
83
#define TEST_BUILTIN_2_SCAL_IMPL (NAME ) \
82
84
{ \
@@ -233,14 +235,18 @@ int main() {
233
235
c[i] = (float )(3 * i);
234
236
}
235
237
236
- TEST_BUILTIN_1 (fabs);
238
+ TEST_BUILTIN_1 (fabs, bfloat16 );
237
239
TEST_BUILTIN_2 (fmin);
238
240
TEST_BUILTIN_2 (fmax);
239
241
TEST_BUILTIN_3 (fma);
240
242
241
243
float check_nan = 0 ;
242
244
TEST_BUILTIN_2_NAN (fmin);
243
245
TEST_BUILTIN_2_NAN (fmax);
246
+
247
+ // Insert NAN value in a to test isnan
248
+ a[0 ] = a[N - 1 ] = NAN;
249
+ TEST_BUILTIN_1 (isnan, bool );
244
250
}
245
251
return 0 ;
246
252
}
0 commit comments