Skip to content

Commit 2bc6c8c

Browse files
address steffen's comments
1 parent ca55e6a commit 2bc6c8c

File tree

1 file changed

+36
-43
lines changed

1 file changed

+36
-43
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,8 @@ class wi_element {
277277
}
278278

279279
#if __SYCL_DEVICE_ONLY__
280-
#define OP(opassign, op) \
281-
template <typename T2> wi_element &operator opassign(const T2 &rhs) { \
280+
#define OP(op) \
281+
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
282282
M.spvm = __spirv_VectorInsertDynamic( \
283283
M.spvm, \
284284
static_cast<T>(__spirv_VectorExtractDynamic(M.spvm, idx) \
@@ -287,17 +287,17 @@ class wi_element {
287287
return *this; \
288288
}
289289
#else // __SYCL_DEVICE_ONLY__
290-
#define OP(opassign, op) \
291-
template <typename T2> wi_element &operator opassign(const T2 &rhs) { \
290+
#define OP(op) \
291+
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
292292
(void)rhs; \
293293
throw runtime_error("joint matrix is not supported on host device.", \
294294
PI_INVALID_DEVICE); \
295295
}
296296
#endif // __SYCL_DEVICE_ONLY__
297-
OP(+=, +)
298-
OP(-=, -)
299-
OP(*=, *)
300-
OP(/=, /)
297+
OP(+)
298+
OP(-)
299+
OP(*)
300+
OP(/)
301301
#undef OP
302302
};
303303

@@ -377,8 +377,8 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
377377
}
378378

379379
#if __SYCL_DEVICE_ONLY__
380-
#define OP(opassign, op) \
381-
wi_element &operator opassign(const uint16_t &rhs) { \
380+
#define OP(op) \
381+
wi_element &operator op##=(const uint16_t &rhs) { \
382382
M.spvm = __spirv_VectorInsertDynamic( \
383383
M.spvm, \
384384
make_bf16(make_fp32(__spirv_VectorExtractDynamic(M.spvm, idx) \
@@ -387,57 +387,50 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
387387
return *this; \
388388
}
389389
#else // __SYCL_DEVICE_ONLY__
390-
#define OP(opassign, op) \
391-
wi_element &operator opassign(const uint16_t &rhs) { \
390+
#define OP(op) \
391+
wi_element &operator op##=(const uint16_t &rhs) { \
392392
(void)rhs; \
393393
throw runtime_error("joint matrix is not supported on host device.", \
394394
PI_INVALID_DEVICE); \
395395
}
396396
#endif // __SYCL_DEVICE_ONLY__
397-
OP(+=, +)
398-
OP(-=, -)
399-
OP(*=, *)
400-
OP(/=, /)
397+
OP(+)
398+
OP(-)
399+
OP(*)
400+
OP(/)
401401
#undef OP
402402

403+
template <typename T1, typename T2> struct Converter {
404+
static T2 convert(const T1 &from) { return static_cast<T2>(from); }
405+
};
406+
407+
template <typename T> struct Converter<T, uint16_t> {
408+
static uint16_t convert(const T &from) { return make_bf16(from); }
409+
};
403410
#if __SYCL_DEVICE_ONLY__
404-
#define OP(type, op) \
411+
#define OP(input_type, type, op) \
405412
friend type operator op( \
406413
const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs, \
407414
const uint16_t &rhs) { \
408-
return make_bf16(make_fp32( \
415+
return Converter<input_type, type>::convert(make_fp32( \
409416
__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) op make_fp32(rhs)); \
410417
} \
411418
friend type operator op( \
412419
const uint16_t &lhs, \
413420
const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &rhs) { \
414-
return make_bf16(make_fp32( \
421+
return Converter<input_type, type>::convert(make_fp32( \
415422
__spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx)) op make_fp32(lhs)); \
416423
}
417-
OP(uint16_t, +)
418-
OP(uint16_t, -)
419-
OP(uint16_t, *)
420-
OP(uint16_t, /)
421-
#undef OP
422-
#define OP(type, op) \
423-
friend type operator op( \
424-
const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &lhs, \
425-
const uint16_t &rhs) { \
426-
return type{make_fp32(__spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx)) \
427-
op make_fp32(rhs)}; \
428-
} \
429-
friend type operator op( \
430-
const uint16_t &lhs, \
431-
const wi_element<uint16_t, NumRows, NumCols, Layout, Group> &rhs) { \
432-
return type{make_fp32(__spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx)) \
433-
op make_fp32(lhs)}; \
434-
}
435-
OP(bool, ==)
436-
OP(bool, !=)
437-
OP(bool, <)
438-
OP(bool, >)
439-
OP(bool, <=)
440-
OP(bool, >=)
424+
OP(float, uint16_t, +)
425+
OP(float, uint16_t, -)
426+
OP(float, uint16_t, *)
427+
OP(float, uint16_t, /)
428+
OP(bool, bool, ==)
429+
OP(bool, bool, !=)
430+
OP(bool, bool, <)
431+
OP(bool, bool, >)
432+
OP(bool, bool, <=)
433+
OP(bool, bool, >=)
441434
#undef OP
442435
#else // __SYCL_DEVICE_ONLY__
443436
#define OP(type, op) \

0 commit comments

Comments
 (0)