@@ -36,7 +36,9 @@ typedef std::map<
36
36
std::type_index,
37
37
std::variant<
38
38
std::vector<float >,
39
- std::vector<double >>>
39
+ std::vector<double >,
40
+ std::vector<exec_aten::Half>,
41
+ std::vector<exec_aten::BFloat16>>>
40
42
FloatingTypeToDataMap;
41
43
42
44
typedef std::map<
@@ -309,9 +311,9 @@ TEST_F(OpToTest, AllDtypesSupported) {
309
311
ScalarType::OUTPUT_DTYPE>(test_cases);
310
312
311
313
#define TEST_ENTRY (INPUT_CTYPE, INPUT_DTYPE ) \
312
- ET_FORALL_REAL_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
314
+ ET_FORALL_REALHBF16_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
313
315
314
- ET_FORALL_REAL_TYPES (TEST_ENTRY);
316
+ ET_FORALL_REALHBF16_TYPES (TEST_ENTRY);
315
317
316
318
#undef TEST_ENTRY
317
319
#undef TEST_KERNEL
@@ -323,14 +325,14 @@ TEST_F(OpToTest, BoolTests) {
323
325
#define TEST_TO_BOOL (INPUT_CTYPE, INPUT_DTYPE ) \
324
326
test_runner_to_bool<INPUT_CTYPE, ScalarType::INPUT_DTYPE>( \
325
327
test_case_to_bool, result_to_bool);
326
- ET_FORALL_REAL_TYPES (TEST_TO_BOOL);
328
+ ET_FORALL_REALHBF16_TYPES (TEST_TO_BOOL);
327
329
328
330
std::vector<uint8_t > test_case_from_bool = {true , true , false };
329
331
std::vector<double > result_from_bool = {1.0 , 1.0 , 0 };
330
332
#define TEST_FROM_BOOL (OUTPUT_CTYPE, OUTPUT_DTYPE ) \
331
333
test_runner_from_bool<OUTPUT_CTYPE, ScalarType::OUTPUT_DTYPE>( \
332
334
test_case_from_bool, result_from_bool);
333
- ET_FORALL_REAL_TYPES (TEST_FROM_BOOL);
335
+ ET_FORALL_REALHBF16_TYPES (TEST_FROM_BOOL);
334
336
}
335
337
336
338
TEST_F (OpToTest, NanInfSupported) {
@@ -349,9 +351,9 @@ TEST_F(OpToTest, NanInfSupported) {
349
351
ScalarType::OUTPUT_DTYPE>(test_cases);
350
352
351
353
#define TEST_ENTRY (INPUT_CTYPE, INPUT_DTYPE ) \
352
- ET_FORALL_FLOAT_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
354
+ ET_FORALL_FLOATHBF16_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
353
355
354
- ET_FORALL_FLOAT_TYPES (TEST_ENTRY);
356
+ ET_FORALL_FLOATHBF16_TYPES (TEST_ENTRY);
355
357
356
358
#undef TEST_ENTRY
357
359
#undef TEST_KERNEL
@@ -381,6 +383,13 @@ TEST_F(OpToTest, HardcodeFloatConvertInt) {
381
383
-0.30919688936285893988 };
382
384
// clang-format on
383
385
386
+ std::vector<exec_aten::Half> half_data;
387
+ std::vector<exec_aten::BFloat16> bf16_data;
388
+ for (auto d : double_data) {
389
+ half_data.emplace_back (d);
390
+ bf16_data.emplace_back (d);
391
+ }
392
+
384
393
std::vector<int64_t > int64_data = {
385
394
-1 , -4 , 2 , -2 , 3 , 3 , -3 , -4 , 3 , 3 , 0 , 2 , 0 , -1 , 0 };
386
395
std::vector<int32_t > int32_data = {
@@ -394,6 +403,8 @@ TEST_F(OpToTest, HardcodeFloatConvertInt) {
394
403
FloatingTypeToDataMap floating_point_data;
395
404
floating_point_data[typeid (float )] = float_data;
396
405
floating_point_data[typeid (double )] = double_data;
406
+ floating_point_data[typeid (exec_aten::Half)] = half_data;
407
+ floating_point_data[typeid (exec_aten::BFloat16)] = bf16_data;
397
408
398
409
// Gathering all int data together for better traversial
399
410
IntTypeToDataMap int_data;
@@ -412,7 +423,7 @@ TEST_F(OpToTest, HardcodeFloatConvertInt) {
412
423
#define TEST_ENTRY (INPUT_CTYPE, INPUT_DTYPE ) \
413
424
ET_FORALL_INT_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
414
425
415
- ET_FORALL_FLOAT_TYPES (TEST_ENTRY);
426
+ ET_FORALL_FLOATHBF16_TYPES (TEST_ENTRY);
416
427
}
417
428
418
429
TEST_F (OpToTest, MismatchedSizesDie) {
0 commit comments