@@ -504,20 +504,19 @@ template <typename Group, typename T, matrix_use Use, size_t NumRows,
504
504
void joint_matrix_load (
505
505
Group sg, joint_matrix<T, Use, NumRows, NumCols, Layout, Group> &res,
506
506
multi_ptr<T, Space> src, size_t stride) {
507
- #ifdef __SYCL_DEVICE_ONLY__
508
- #ifdef __NVPTX__
507
+ #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
509
508
sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
510
509
Layout, Space>{}
511
510
.load (res, src, stride);
512
- #endif
513
511
#else
514
512
(void )sg;
515
513
(void )res;
516
514
(void )src;
517
515
(void )stride;
518
- throw runtime_error (" joint_matrix_load is not supported on host device." ,
516
+ throw runtime_error (" When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_load is "
517
+ " only supported by CUDA devices" ,
519
518
PI_INVALID_DEVICE);
520
- #endif // __SYCL_DEVICE_ONLY__*/
519
+ #endif // defined( __SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
521
520
}
522
521
523
522
template <typename Group, typename T, size_t NumRows, size_t NumCols,
@@ -526,20 +525,19 @@ void joint_matrix_store(Group sg,
526
525
joint_matrix<T, matrix_use::accumulator, NumRows,
527
526
NumCols, Layout, Group> &src,
528
527
multi_ptr<T, Space> dst, size_t stride) {
529
- #ifdef __SYCL_DEVICE_ONLY__
530
- #ifdef __NVPTX__
531
- sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
532
- Layout, Space>{}
528
+ #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
529
+ sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols, Layout,
530
+ Space>{}
533
531
.store (src, dst, stride);
534
- #endif
535
532
#else
536
533
(void )sg;
537
534
(void )src;
538
535
(void )dst;
539
536
(void )stride;
540
- throw runtime_error (" joint_matrix_store is not supported on host device." ,
537
+ throw runtime_error (" When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_store is "
538
+ " only supported by CUDA devices" ,
541
539
PI_INVALID_DEVICE);
542
- #endif // __SYCL_DEVICE_ONLY__*/
540
+ #endif // defined( __SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
543
541
}
544
542
545
543
template <typename Group, typename T1, typename T2, std::size_t M,
@@ -550,20 +548,19 @@ joint_matrix_mad(
550
548
Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
551
549
joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
552
550
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
553
- #ifdef __SYCL_DEVICE_ONLY__
554
- #ifdef __NVPTX__
551
+ #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
555
552
return sycl::ext::oneapi::detail::joint_matrix_mad_impl<
556
553
T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{}
557
554
.mad (A, B, C);
558
- #endif
559
555
#else
560
556
(void )sg;
561
557
(void )A;
562
558
(void )B;
563
559
(void )C;
564
- throw runtime_error (" joint_matrix_mad is not supported on host device." ,
560
+ throw runtime_error (" When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_mad is "
561
+ " only supported by CUDA devices" ,
565
562
PI_INVALID_DEVICE);
566
- #endif // __SYCL_DEVICE_ONLY__*/
563
+ #endif // defined( __SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
567
564
}
568
565
569
566
} // namespace experimental::matrix
0 commit comments