Skip to content

Commit 2c0407f

Browse files
committed
Add error if cuda device not used.
Signed-off-by: jack.kirk <[email protected]>
1 parent 90fcd0e commit 2c0407f

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -504,20 +504,19 @@ template <typename Group, typename T, matrix_use Use, size_t NumRows,
504504
void joint_matrix_load(
505505
Group sg, joint_matrix<T, Use, NumRows, NumCols, Layout, Group> &res,
506506
multi_ptr<T, Space> src, size_t stride) {
507-
#ifdef __SYCL_DEVICE_ONLY__
508-
#ifdef __NVPTX__
507+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
509508
sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
510509
Layout, Space>{}
511510
.load(res, src, stride);
512-
#endif
513511
#else
514512
(void)sg;
515513
(void)res;
516514
(void)src;
517515
(void)stride;
518-
throw runtime_error("joint_matrix_load is not supported on host device.",
516+
throw runtime_error("When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_load is "
517+
"only supported by CUDA devices",
519518
PI_INVALID_DEVICE);
520-
#endif // __SYCL_DEVICE_ONLY__*/
519+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
521520
}
522521

523522
template <typename Group, typename T, size_t NumRows, size_t NumCols,
@@ -526,20 +525,19 @@ void joint_matrix_store(Group sg,
526525
joint_matrix<T, matrix_use::accumulator, NumRows,
527526
NumCols, Layout, Group> &src,
528527
multi_ptr<T, Space> dst, size_t stride) {
529-
#ifdef __SYCL_DEVICE_ONLY__
530-
#ifdef __NVPTX__
531-
sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
532-
Layout, Space>{}
528+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
529+
sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols, Layout,
530+
Space>{}
533531
.store(src, dst, stride);
534-
#endif
535532
#else
536533
(void)sg;
537534
(void)src;
538535
(void)dst;
539536
(void)stride;
540-
throw runtime_error("joint_matrix_store is not supported on host device.",
537+
throw runtime_error("When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_store is "
538+
"only supported by CUDA devices",
541539
PI_INVALID_DEVICE);
542-
#endif // __SYCL_DEVICE_ONLY__*/
540+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
543541
}
544542

545543
template <typename Group, typename T1, typename T2, std::size_t M,
@@ -550,20 +548,19 @@ joint_matrix_mad(
550548
Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
551549
joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
552550
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
553-
#ifdef __SYCL_DEVICE_ONLY__
554-
#ifdef __NVPTX__
551+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
555552
return sycl::ext::oneapi::detail::joint_matrix_mad_impl<
556553
T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{}
557554
.mad(A, B, C);
558-
#endif
559555
#else
560556
(void)sg;
561557
(void)A;
562558
(void)B;
563559
(void)C;
564-
throw runtime_error("joint_matrix_mad is not supported on host device.",
560+
throw runtime_error("When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_mad is "
561+
"only supported by CUDA devices",
565562
PI_INVALID_DEVICE);
566-
#endif // __SYCL_DEVICE_ONLY__*/
563+
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
567564
}
568565

569566
} // namespace experimental::matrix

0 commit comments

Comments
 (0)