|
11 | 11 | #include <CL/__spirv/spirv_ops.hpp>
|
12 | 12 | #include <CL/sycl/detail/defines_elementary.hpp>
|
13 | 13 | #include <CL/sycl/feature_test.hpp>
|
14 |
| -#include <sycl/ext/oneapi/experimental/bfloat16.hpp> |
15 | 14 |
|
16 | 15 | __SYCL_INLINE_NAMESPACE(cl) {
|
17 | 16 | namespace sycl {
|
@@ -454,156 +453,6 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
|
454 | 453 | #undef OP
|
455 | 454 | };
|
456 | 455 |
|
457 |
| -template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group> |
458 |
| -class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols, |
459 |
| - Layout, Group> { |
460 |
| - joint_matrix<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols, |
461 |
| - Layout, Group> &M; |
462 |
| - std::size_t idx; |
463 |
| - |
464 |
| -public: |
465 |
| - wi_element(joint_matrix<sycl::ext::oneapi::experimental::bfloat16, NumRows, |
466 |
| - NumCols, Layout, Group> &Mat, |
467 |
| - std::size_t i) |
468 |
| - : M(Mat), idx(i) {} |
469 |
| - operator sycl::ext::oneapi::experimental::bfloat16() { |
470 |
| -#ifdef __SYCL_DEVICE_ONLY__ |
471 |
| - return __spirv_VectorExtractDynamic(M.spvm, idx); |
472 |
| -#else |
473 |
| - throw runtime_error("joint matrix is not supported on host device.", |
474 |
| - PI_INVALID_DEVICE); |
475 |
| -#endif // __SYCL_DEVICE_ONLY__ |
476 |
| - } |
477 |
| - |
478 |
| - explicit operator bool() { |
479 |
| -#ifdef __SYCL_DEVICE_ONLY__ |
480 |
| - return std::fabs(static_cast<float>(__spirv_VectorExtractDynamic( |
481 |
| - M.spvm, idx))) >= std::numeric_limits<float>::epsilon(); |
482 |
| -#else |
483 |
| - throw runtime_error("joint matrix is not supported on host device.", |
484 |
| - PI_INVALID_DEVICE); |
485 |
| -#endif // __SYCL_DEVICE_ONLY__ |
486 |
| - } |
487 |
| - |
488 |
| - wi_element &operator=(const sycl::ext::oneapi::experimental::bfloat16 &rhs) { |
489 |
| -#ifdef __SYCL_DEVICE_ONLY__ |
490 |
| - M.spvm = __spirv_VectorInsertDynamic(M.spvm, rhs, idx); |
491 |
| - return *this; |
492 |
| -#else |
493 |
| - (void)rhs; |
494 |
| - throw runtime_error("joint matrix is not supported on host device.", |
495 |
| - PI_INVALID_DEVICE); |
496 |
| -#endif // __SYCL_DEVICE_ONLY__ |
497 |
| - } |
498 |
| - |
499 |
| - wi_element & |
500 |
| - operator=(const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, |
501 |
| - NumCols, Layout, Group> &rhs) { |
502 |
| -#ifdef __SYCL_DEVICE_ONLY__ |
503 |
| - M.spvm = __spirv_VectorInsertDynamic( |
504 |
| - M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx); |
505 |
| - return *this; |
506 |
| -#else |
507 |
| - (void)rhs; |
508 |
| - throw runtime_error("joint matrix is not supported on host device.", |
509 |
| - PI_INVALID_DEVICE); |
510 |
| -#endif // __SYCL_DEVICE_ONLY__ |
511 |
| - } |
512 |
| - |
513 |
| -#if __SYCL_DEVICE_ONLY__ |
514 |
| -#define OP(opassign, op) \ |
515 |
| - wi_element &operator opassign( \ |
516 |
| - const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ |
517 |
| - M.spvm = __spirv_VectorInsertDynamic( \ |
518 |
| - M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx); \ |
519 |
| - return *this; \ |
520 |
| - } |
521 |
| -#else // __SYCL_DEVICE_ONLY__ |
522 |
| -#define OP(opassign, op) \ |
523 |
| - wi_element &operator opassign( \ |
524 |
| - const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ |
525 |
| - (void)rhs; \ |
526 |
| - throw runtime_error("joint matrix is not supported on host device.", \ |
527 |
| - PI_INVALID_DEVICE); \ |
528 |
| - } |
529 |
| -#endif // __SYCL_DEVICE_ONLY__ |
530 |
| - OP(+=, +) |
531 |
| - OP(-=, -) |
532 |
| - OP(*=, *) |
533 |
| - OP(/=, /) |
534 |
| -#undef OP |
535 |
| - |
536 |
| -#if __SYCL_DEVICE_ONLY__ |
537 |
| -#define OP(type, op) \ |
538 |
| - friend type operator op( \ |
539 |
| - const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \ |
540 |
| - NumCols, Layout, Group> &lhs, \ |
541 |
| - const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ |
542 |
| - return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \ |
543 |
| - } \ |
544 |
| - friend type operator op( \ |
545 |
| - const sycl::ext::oneapi::experimental::bfloat16 &lhs, \ |
546 |
| - const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \ |
547 |
| - NumCols, Layout, Group> &rhs) { \ |
548 |
| - return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs; \ |
549 |
| - } |
550 |
| - OP(sycl::ext::oneapi::experimental::bfloat16, +) |
551 |
| - OP(sycl::ext::oneapi::experimental::bfloat16, -) |
552 |
| - OP(sycl::ext::oneapi::experimental::bfloat16, *) |
553 |
| - OP(sycl::ext::oneapi::experimental::bfloat16, /) |
554 |
| -#undef OP |
555 |
| -#define OP(type, op) \ |
556 |
| - friend type operator op( \ |
557 |
| - const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \ |
558 |
| - NumCols, Layout, Group> &lhs, \ |
559 |
| - const sycl::ext::oneapi::experimental::bfloat16 &rhs) { \ |
560 |
| - return type{static_cast<float>(__spirv_VectorExtractDynamic( \ |
561 |
| - lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)}; \ |
562 |
| - } \ |
563 |
| - friend type operator op( \ |
564 |
| - const sycl::ext::oneapi::experimental::bfloat16 &lhs, \ |
565 |
| - const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \ |
566 |
| - NumCols, Layout, Group> &rhs) { \ |
567 |
| - return type{static_cast<float>(__spirv_VectorExtractDynamic( \ |
568 |
| - rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs)}; \ |
569 |
| - } |
570 |
| - OP(bool, ==) |
571 |
| - OP(bool, !=) |
572 |
| - OP(bool, <) |
573 |
| - OP(bool, >) |
574 |
| - OP(bool, <=) |
575 |
| - OP(bool, >=) |
576 |
| -#undef OP |
577 |
| -#else // __SYCL_DEVICE_ONLY__ |
578 |
| -#define OP(type, op) \ |
579 |
| - friend type operator op( \ |
580 |
| - const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \ |
581 |
| - NumCols, Layout, Group> &, \ |
582 |
| - const sycl::ext::oneapi::experimental::bfloat16 &) { \ |
583 |
| - throw runtime_error("joint matrix is not supported on host device.", \ |
584 |
| - PI_INVALID_DEVICE); \ |
585 |
| - } \ |
586 |
| - friend type operator op( \ |
587 |
| - const sycl::ext::oneapi::experimental::bfloat16 &, \ |
588 |
| - const wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, \ |
589 |
| - NumCols, Layout, Group> &) { \ |
590 |
| - throw runtime_error("joint matrix is not supported on host device.", \ |
591 |
| - PI_INVALID_DEVICE); \ |
592 |
| - } |
593 |
| - OP(sycl::ext::oneapi::experimental::bfloat16, +) |
594 |
| - OP(sycl::ext::oneapi::experimental::bfloat16, -) |
595 |
| - OP(sycl::ext::oneapi::experimental::bfloat16, *) |
596 |
| - OP(sycl::ext::oneapi::experimental::bfloat16, /) |
597 |
| - OP(bool, ==) |
598 |
| - OP(bool, !=) |
599 |
| - OP(bool, <) |
600 |
| - OP(bool, >) |
601 |
| - OP(bool, <=) |
602 |
| - OP(bool, >=) |
603 |
| -#undef OP |
604 |
| -#endif // __SYCL_DEVICE_ONLY__ |
605 |
| -}; |
606 |
| - |
607 | 456 | template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout,
|
608 | 457 | typename Group>
|
609 | 458 | class wi_data {
|
|
0 commit comments