10
10
11
11
#include < CL/__spirv/spirv_ops.hpp>
12
12
#include < sycl/detail/defines_elementary.hpp>
13
- #include < sycl/ext/oneapi/experimental/ bfloat16.hpp>
13
+ #include < sycl/ext/oneapi/bfloat16.hpp>
14
14
#include < sycl/feature_test.hpp>
15
15
16
16
__SYCL_INLINE_NAMESPACE (cl) {
@@ -458,18 +458,16 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
458
458
};
459
459
460
460
template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group>
461
- class wi_element <sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
462
- Layout, Group> {
463
- joint_matrix<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
464
- Layout, Group> &M;
461
+ class wi_element <sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, Group> {
462
+ joint_matrix<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, Group> &M;
465
463
std::size_t idx;
466
464
467
465
public:
468
- wi_element (joint_matrix<sycl::ext::oneapi::experimental:: bfloat16, NumRows,
469
- NumCols, Layout, Group> &Mat,
466
+ wi_element (joint_matrix<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout ,
467
+ Group> &Mat,
470
468
std::size_t i)
471
469
: M(Mat), idx(i) {}
472
- operator sycl::ext::oneapi::experimental:: bfloat16 () {
470
+ operator sycl::ext::oneapi::bfloat16 () {
473
471
#ifdef __SYCL_DEVICE_ONLY__
474
472
return __spirv_VectorExtractDynamic (M.spvm , idx);
475
473
#else
@@ -488,7 +486,7 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
488
486
#endif // __SYCL_DEVICE_ONLY__
489
487
}
490
488
491
- wi_element &operator =(const sycl::ext::oneapi::experimental:: bfloat16 &rhs) {
489
+ wi_element &operator =(const sycl::ext::oneapi::bfloat16 &rhs) {
492
490
#ifdef __SYCL_DEVICE_ONLY__
493
491
M.spvm = __spirv_VectorInsertDynamic (M.spvm , rhs, idx);
494
492
return *this ;
@@ -499,9 +497,8 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
499
497
#endif // __SYCL_DEVICE_ONLY__
500
498
}
501
499
502
- wi_element &
503
- operator =(const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows,
504
- NumCols, Layout, Group> &rhs) {
500
+ wi_element &operator =(const wi_element<sycl::ext::oneapi::bfloat16, NumRows,
501
+ NumCols, Layout, Group> &rhs) {
505
502
#ifdef __SYCL_DEVICE_ONLY__
506
503
M.spvm = __spirv_VectorInsertDynamic (
507
504
M.spvm , __spirv_VectorExtractDynamic (rhs.M .spvm , rhs.idx ), idx);
@@ -515,16 +512,14 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
515
512
516
513
#if __SYCL_DEVICE_ONLY__
517
514
#define OP (opassign, op ) \
518
- wi_element &operator opassign ( \
519
- const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
515
+ wi_element &operator opassign (const sycl::ext::oneapi::bfloat16 &rhs) { \
520
516
M.spvm = __spirv_VectorInsertDynamic ( \
521
517
M.spvm , __spirv_VectorExtractDynamic (M.spvm , idx) op rhs, idx); \
522
518
return *this ; \
523
519
}
524
520
#else // __SYCL_DEVICE_ONLY__
525
521
#define OP (opassign, op ) \
526
- wi_element &operator opassign ( \
527
- const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \
522
+ wi_element &operator opassign (const sycl::ext::oneapi::bfloat16 &rhs) { \
528
523
(void )rhs; \
529
524
throw runtime_error (" joint matrix is not supported on host device." , \
530
525
PI_ERROR_INVALID_DEVICE); \
@@ -539,34 +534,34 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
539
534
#if __SYCL_DEVICE_ONLY__
540
535
#define OP (type, op ) \
541
536
friend type operator op ( \
542
- const wi_element<sycl::ext::oneapi::experimental:: bfloat16, NumRows, \
543
- NumCols, Layout, Group> &lhs, \
544
- const sycl::ext::oneapi::experimental:: bfloat16 &rhs) { \
537
+ const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
538
+ Group> &lhs, \
539
+ const sycl::ext::oneapi::bfloat16 &rhs) { \
545
540
return __spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx ) op rhs; \
546
541
} \
547
542
friend type operator op ( \
548
- const sycl::ext::oneapi::experimental:: bfloat16 &lhs, \
549
- const wi_element<sycl::ext::oneapi::experimental:: bfloat16, NumRows, \
550
- NumCols, Layout, Group> &rhs) { \
543
+ const sycl::ext::oneapi::bfloat16 &lhs, \
544
+ const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
545
+ Group> &rhs) { \
551
546
return __spirv_VectorExtractDynamic (rhs.M .spvm , rhs.idx ) op lhs; \
552
547
}
553
- OP (sycl::ext::oneapi::experimental:: bfloat16, +)
554
- OP(sycl::ext::oneapi::experimental:: bfloat16, -)
555
- OP(sycl::ext::oneapi::experimental:: bfloat16, *)
556
- OP(sycl::ext::oneapi::experimental:: bfloat16, /)
548
+ OP (sycl::ext::oneapi::bfloat16, +)
549
+ OP(sycl::ext::oneapi::bfloat16, -)
550
+ OP(sycl::ext::oneapi::bfloat16, *)
551
+ OP(sycl::ext::oneapi::bfloat16, /)
557
552
#undef OP
558
553
#define OP (type, op ) \
559
554
friend type operator op ( \
560
- const wi_element<sycl::ext::oneapi::experimental:: bfloat16, NumRows, \
561
- NumCols, Layout, Group> &lhs, \
562
- const sycl::ext::oneapi::experimental:: bfloat16 &rhs) { \
555
+ const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
556
+ Group> &lhs, \
557
+ const sycl::ext::oneapi::bfloat16 &rhs) { \
563
558
return type{static_cast <float >(__spirv_VectorExtractDynamic ( \
564
559
lhs.M .spvm , lhs.idx )) op static_cast <float >(rhs)}; \
565
560
} \
566
561
friend type operator op ( \
567
- const sycl::ext::oneapi::experimental:: bfloat16 &lhs, \
568
- const wi_element<sycl::ext::oneapi::experimental:: bfloat16, NumRows, \
569
- NumCols, Layout, Group> &rhs) { \
562
+ const sycl::ext::oneapi::bfloat16 &lhs, \
563
+ const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
564
+ Group> &rhs) { \
570
565
return type{static_cast <float >(__spirv_VectorExtractDynamic ( \
571
566
rhs.M .spvm , rhs.idx )) op static_cast <float >(lhs)}; \
572
567
}
@@ -579,24 +574,23 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
579
574
#undef OP
580
575
#else // __SYCL_DEVICE_ONLY__
581
576
#define OP (type, op ) \
582
- friend type operator op ( \
583
- const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \
584
- NumCols, Layout, Group> &, \
585
- const sycl::ext::oneapi::experimental::bfloat16 &) { \
577
+ friend type operator op (const wi_element<sycl::ext::oneapi::bfloat16, \
578
+ NumRows, NumCols, Layout, Group> &, \
579
+ const sycl::ext::oneapi::bfloat16 &) { \
586
580
throw runtime_error (" joint matrix is not supported on host device." , \
587
581
PI_ERROR_INVALID_DEVICE); \
588
582
} \
589
583
friend type operator op ( \
590
- const sycl::ext::oneapi::experimental:: bfloat16 &, \
591
- const wi_element<sycl::ext::oneapi::experimental:: bfloat16, NumRows, \
592
- NumCols, Layout, Group> &) { \
584
+ const sycl::ext::oneapi::bfloat16 &, \
585
+ const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Layout, \
586
+ Group> &) { \
593
587
throw runtime_error (" joint matrix is not supported on host device." , \
594
588
PI_ERROR_INVALID_DEVICE); \
595
589
}
596
- OP (sycl::ext::oneapi::experimental:: bfloat16, +)
597
- OP(sycl::ext::oneapi::experimental:: bfloat16, -)
598
- OP(sycl::ext::oneapi::experimental:: bfloat16, *)
599
- OP(sycl::ext::oneapi::experimental:: bfloat16, /)
590
+ OP (sycl::ext::oneapi::bfloat16, +)
591
+ OP(sycl::ext::oneapi::bfloat16, -)
592
+ OP(sycl::ext::oneapi::bfloat16, *)
593
+ OP(sycl::ext::oneapi::bfloat16, /)
600
594
OP(bool , ==)
601
595
OP(bool , !=)
602
596
OP(bool , <)
0 commit comments