@@ -277,8 +277,8 @@ class wi_element {
277
277
}
278
278
279
279
#if __SYCL_DEVICE_ONLY__
280
- #define OP (opassign, op ) \
281
- template <typename T2> wi_element &operator opassign (const T2 &rhs) { \
280
+ #define OP (op ) \
281
+ template <typename T2> wi_element &operator op##= (const T2 &rhs) { \
282
282
M.spvm = __spirv_VectorInsertDynamic ( \
283
283
M.spvm , \
284
284
static_cast <T>(__spirv_VectorExtractDynamic (M.spvm , idx) \
@@ -287,17 +287,17 @@ class wi_element {
287
287
return *this ; \
288
288
}
289
289
#else // __SYCL_DEVICE_ONLY__
290
- #define OP (opassign, op ) \
291
- template <typename T2> wi_element &operator opassign (const T2 &rhs) { \
290
+ #define OP (op ) \
291
+ template <typename T2> wi_element &operator op##= (const T2 &rhs) { \
292
292
(void )rhs; \
293
293
throw runtime_error (" joint matrix is not supported on host device." , \
294
294
PI_INVALID_DEVICE); \
295
295
}
296
296
#endif // __SYCL_DEVICE_ONLY__
297
- OP (+=, + )
298
- OP (-=, - )
299
- OP (*=, * )
300
- OP (/=, / )
297
+ OP (+)
298
+ OP (-)
299
+ OP (*)
300
+ OP (/)
301
301
#undef OP
302
302
};
303
303
@@ -377,8 +377,8 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
377
377
}
378
378
379
379
#if __SYCL_DEVICE_ONLY__
380
- #define OP (opassign, op ) \
381
- wi_element &operator opassign (const uint16_t &rhs) { \
380
+ #define OP (op ) \
381
+ wi_element &operator op##= (const uint16_t &rhs) { \
382
382
M.spvm = __spirv_VectorInsertDynamic ( \
383
383
M.spvm , \
384
384
make_bf16 (make_fp32 (__spirv_VectorExtractDynamic (M.spvm , idx) \
@@ -387,57 +387,50 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
387
387
return *this ; \
388
388
}
389
389
#else // __SYCL_DEVICE_ONLY__
390
- #define OP (opassign, op ) \
391
- wi_element &operator opassign (const uint16_t &rhs) { \
390
+ #define OP (op ) \
391
+ wi_element &operator op##= (const uint16_t &rhs) { \
392
392
(void )rhs; \
393
393
throw runtime_error (" joint matrix is not supported on host device." , \
394
394
PI_INVALID_DEVICE); \
395
395
}
396
396
#endif // __SYCL_DEVICE_ONLY__
397
- OP (+=, + )
398
- OP (-=, - )
399
- OP (*=, * )
400
- OP (/=, / )
397
+ OP (+)
398
+ OP (-)
399
+ OP (*)
400
+ OP (/)
401
401
#undef OP
402
402
403
+ template <typename T1, typename T2> struct Converter {
404
+ static T2 convert (const T1 &from) { return static_cast <T2>(from); }
405
+ };
406
+
407
+ template <typename T> struct Converter <T, uint16_t > {
408
+ static uint16_t convert (const T &from) { return make_bf16 (from); }
409
+ };
403
410
#if __SYCL_DEVICE_ONLY__
404
- #define OP (type, op ) \
411
+ #define OP (input_type, type, op ) \
405
412
friend type operator op ( \
406
413
const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs, \
407
414
const uint16_t &rhs) { \
408
- return make_bf16 (make_fp32 ( \
415
+ return Converter<input_type, type>:: convert (make_fp32 ( \
409
416
__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) op make_fp32 (rhs)); \
410
417
} \
411
418
friend type operator op ( \
412
419
const uint16_t &lhs, \
413
420
const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &rhs) { \
414
- return make_bf16 (make_fp32 ( \
421
+ return Converter<input_type, type>:: convert (make_fp32 ( \
415
422
__spirv_VectorExtractDynamic (rhs.M .spvm , rhs.idx )) op make_fp32 (lhs)); \
416
423
}
417
- OP (uint16_t , +)
418
- OP(uint16_t , -)
419
- OP(uint16_t , *)
420
- OP(uint16_t , /)
421
- #undef OP
422
- #define OP (type, op ) \
423
- friend type operator op ( \
424
- const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs, \
425
- const uint16_t &rhs) { \
426
- return type{make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) \
427
- op make_fp32 (rhs)}; \
428
- } \
429
- friend type operator op ( \
430
- const uint16_t &lhs, \
431
- const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &rhs) { \
432
- return type{make_fp32 (__spirv_VectorExtractDynamic (rhs.M .spvm , rhs.idx )) \
433
- op make_fp32 (lhs)}; \
434
- }
435
- OP (bool , ==)
436
- OP(bool , !=)
437
- OP(bool , <)
438
- OP(bool , >)
439
- OP(bool , <=)
440
- OP(bool , >=)
424
+ OP (float , uint16_t , +)
425
+ OP(float , uint16_t , -)
426
+ OP(float , uint16_t , *)
427
+ OP(float , uint16_t , /)
428
+ OP(bool , bool , ==)
429
+ OP(bool , bool , !=)
430
+ OP(bool , bool , <)
431
+ OP(bool , bool , >)
432
+ OP(bool , bool , <=)
433
+ OP(bool , bool , >=)
441
434
#undef OP
442
435
#else // __SYCL_DEVICE_ONLY__
443
436
#define OP (type, op ) \
0 commit comments