Skip to content

Commit 2de6266

Browse files
unify host and device for bf16's OP macro
1 parent 2bc6c8c commit 2de6266

File tree

1 file changed

+12
-23
lines changed

1 file changed

+12
-23
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -421,19 +421,8 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
421421
return Converter<input_type, type>::convert(make_fp32( \
422422
__spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx)) op make_fp32(lhs)); \
423423
}
424-
OP(float, uint16_t, +)
425-
OP(float, uint16_t, -)
426-
OP(float, uint16_t, *)
427-
OP(float, uint16_t, /)
428-
OP(bool, bool, ==)
429-
OP(bool, bool, !=)
430-
OP(bool, bool, <)
431-
OP(bool, bool, >)
432-
OP(bool, bool, <=)
433-
OP(bool, bool, >=)
434-
#undef OP
435424
#else // __SYCL_DEVICE_ONLY__
436-
#define OP(type, op) \
425+
#define OP(input_type, type, op) \
437426
friend type operator op( \
438427
const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs, \
439428
const uint16_t &rhs) { \
@@ -450,18 +439,18 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
450439
throw runtime_error("joint matrix is not supported on host device.", \
451440
PI_INVALID_DEVICE); \
452441
}
453-
OP(uint16_t, +)
454-
OP(uint16_t, -)
455-
OP(uint16_t, *)
456-
OP(uint16_t, /)
457-
OP(bool, ==)
458-
OP(bool, !=)
459-
OP(bool, <)
460-
OP(bool, >)
461-
OP(bool, <=)
462-
OP(bool, >=)
463-
#undef OP
464442
#endif // __SYCL_DEVICE_ONLY__
443+
OP(float, uint16_t, +)
444+
OP(float, uint16_t, -)
445+
OP(float, uint16_t, *)
446+
OP(float, uint16_t, /)
447+
OP(bool, bool, ==)
448+
OP(bool, bool, !=)
449+
OP(bool, bool, <)
450+
OP(bool, bool, >)
451+
OP(bool, bool, <=)
452+
OP(bool, bool, >=)
453+
#undef OP
465454
};
466455

467456
template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout,

0 commit comments

Comments
 (0)