@@ -40,7 +40,8 @@ struct joint_matrix {
40
40
41
41
#if defined(__SYCL_DEVICE_ONLY__)
42
42
#if defined(__NVPTX__)
43
- sycl::ext::oneapi::detail::joint_matrix_cuda<T, Use, Rows, Cols, Layout>
43
+ mutable sycl::ext::oneapi::detail::joint_matrix_cuda<T, Use, Rows, Cols,
44
+ Layout>
44
45
cuda_impl;
45
46
#elif defined(__SPIR__)
46
47
__spv::__spirv_JointMatrixINTEL<
@@ -61,19 +62,8 @@ struct joint_matrix {
61
62
}
62
63
#ifdef __SYCL_DEVICE_ONLY__
63
64
#if defined(__SPIR__)
64
- // Generate a non-trivial assignment operator and copy c'tor that prevents
65
- // memcpy from being generated.
66
- // TODO: to remove, when either IGC can handle alloca JointMatrix or
67
- // combination of InstCombine + SROA + mem2reg can remove it
68
- joint_matrix (const joint_matrix &other) {
69
- spvm = other.spvm ;
70
- return *this ;
71
- }
72
-
73
- joint_matrix &operator =(const joint_matrix &rhs) {
74
- spvm = rhs.spvm ;
75
- return *this ;
76
- }
65
+ joint_matrix (const joint_matrix &other) = delete;
66
+ joint_matrix &operator =(const joint_matrix &rhs) = delete ;
77
67
#endif // defined(__SPIR__)
78
68
#endif
79
69
};
@@ -97,10 +87,6 @@ class wi_data {
97
87
size_t length () {
98
88
#if defined(__NVPTX__)
99
89
return jm.cuda_impl .wi_marray .size ();
100
- #else
101
- throw runtime_error (" get_wi_data is available using: "
102
- " ext::intel::experimental::matrix::get_wi_data." ,
103
- PI_ERROR_INVALID_DEVICE);
104
90
#endif
105
91
};
106
92
@@ -109,9 +95,6 @@ class wi_data {
109
95
return (jm.cuda_impl .wi_marray [i]);
110
96
#else
111
97
std::ignore = i;
112
- throw runtime_error (" get_wi_data is available using: "
113
- " ext::intel::experimental::matrix::get_wi_data." ,
114
- PI_ERROR_INVALID_DEVICE);
115
98
#endif
116
99
};
117
100
};
@@ -139,9 +122,8 @@ template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
139
122
__SYCL2020_DEPRECATED (" get_wi_data() is deprecated for CUDA backend. Please "
140
123
" use joint_matrix_apply() instead." )
141
124
#else
142
- __attribute__ ((unavailable(
143
- " get_wi_data can't be used on intel device, please use "
144
- " sycl::ext::intel::experimental::matrix::get_wi_data instead!" )))
125
+ __attribute__ ((unavailable(" get_wi_data() has been removed from the API and "
126
+ " replaced with joint_matrix_apply!" )))
145
127
#endif
146
128
#endif
147
129
inline __SYCL_ALWAYS_INLINE decltype (auto )
@@ -177,7 +159,7 @@ joint_matrix_apply(Group sg, joint_matrix<Group, T, Use, M, N, Layout> &jm,
177
159
using storage_element_type =
178
160
typename oneapi::detail::jm_type_interpretation_helper_trait<
179
161
T>::storage_element_type;
180
- auto wi_data_c = sycl::ext::intel::experimental::matrix ::get_wi_data (sg, jm);
162
+ auto wi_data_c = sycl::ext::oneapi::detail ::get_wi_data (sg, jm);
181
163
for (int i = 0 ; i < wi_data_c.length (); i++) {
182
164
storage_element_type element = wi_data_c[i];
183
165
lambda (element);
@@ -260,7 +242,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
260
242
Ptr, stride, __spv::MatrixLayout::ColumnMajor,
261
243
spv_scope_traits<Group>::value);
262
244
break ;
263
- case sycl::ext::intel::experimental::matrix:: layout::packed :
245
+ case layout::ext_intel_packed :
264
246
res.spvm = __spirv_JointMatrixLoadINTEL<
265
247
DecorT, S, NumRows, NumCols,
266
248
spv_matrix_use_traits<use::accumulator>::value,
@@ -322,8 +304,9 @@ template <typename Group, typename T, size_t NumRows, size_t NumCols,
322
304
access::address_space Space, access::decorated IsDecorated>
323
305
inline __SYCL_ALWAYS_INLINE void joint_matrix_store (
324
306
Group,
325
- joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
326
- sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
307
+ const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
308
+ sycl::ext::oneapi::experimental::matrix::layout::dynamic>
309
+ &src,
327
310
multi_ptr<T, Space, IsDecorated> dst, size_t stride,
328
311
sycl::ext::oneapi::experimental::matrix::layout Layout) {
329
312
#if defined(__SYCL_DEVICE_ONLY__)
@@ -355,7 +338,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
355
338
Ptr, src.spvm , stride, __spv::MatrixLayout::ColumnMajor,
356
339
spv_scope_traits<Group>::value);
357
340
break ;
358
- case sycl::ext::intel::experimental::matrix:: layout::packed :
341
+ case layout::ext_intel_packed :
359
342
__spirv_JointMatrixStoreINTEL<
360
343
DecorT, T, NumRows, NumCols,
361
344
spv_matrix_use_traits<use::accumulator>::value,
@@ -375,51 +358,77 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
375
358
#endif // defined(__SYCL_DEVICE_ONLY__)
376
359
}
377
360
378
- template <typename Group, typename Ta, typename Tb, typename Tc, std::size_t M,
379
- std::size_t K, std::size_t N, layout LayoutA, layout LayoutB>
380
- inline __SYCL_ALWAYS_INLINE
381
- joint_matrix<Group, Tc, use::accumulator, M, N,
382
- sycl::ext::oneapi::experimental::matrix::layout::dynamic>
383
- joint_matrix_mad (
384
- Group, joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
385
- joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
386
- joint_matrix<Group, Tc, use::accumulator, M, N,
387
- sycl::ext::oneapi::experimental::matrix::layout::dynamic>
388
- &C) {
361
+ template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
362
+ std::size_t M, std::size_t K, std::size_t N, layout LayoutA,
363
+ layout LayoutB>
364
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_mad (
365
+ Group,
366
+ joint_matrix<Group, Td, use::accumulator, M, N,
367
+ sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
368
+ const joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
369
+ const joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
370
+ const joint_matrix<Group, Tc, use::accumulator, M, N,
371
+ sycl::ext::oneapi::experimental::matrix::layout::dynamic>
372
+ &C) {
389
373
#if defined(__SYCL_DEVICE_ONLY__)
390
374
#if defined(__NVPTX__)
391
375
if constexpr (std::is_same<Ta, Tb>::value) {
392
- joint_matrix<Group, Tc, use::accumulator, M, N,
393
- sycl::ext::oneapi::experimental::matrix::layout::dynamic>
394
- D;
395
376
sycl::ext::oneapi::detail::joint_matrix_mad_cuda<Ta, Tc, M, K, N, LayoutA,
396
377
LayoutB>(
397
378
D.cuda_impl , A.cuda_impl , B.cuda_impl , C.cuda_impl );
398
- return D;
399
379
} else {
400
380
assert (false && " Ta != Tb : In the CUDA backend joint_matrix_mad "
401
381
" requires that joint_matrix data types Ta and Tb match" );
402
382
}
403
383
#else
404
- joint_matrix<Group, Tc, use::accumulator, M, N, layout::dynamic> res;
405
384
if constexpr (std::is_same<Ta, uint16_t >::value &&
406
385
std::is_same<Tb, uint16_t >::value &&
407
386
std::is_same<Tc, float >::value)
408
- res .spvm = __spirv_JointMatrixMadINTEL (A.spvm , B.spvm , C.spvm );
387
+ D .spvm = __spirv_JointMatrixMadINTEL (A.spvm , B.spvm , C.spvm );
409
388
else if constexpr (std::is_unsigned<Ta>::value && std::is_unsigned<Tb>::value)
410
- res .spvm = __spirv_JointMatrixUUMadINTEL (A.spvm , B.spvm , C.spvm );
389
+ D .spvm = __spirv_JointMatrixUUMadINTEL (A.spvm , B.spvm , C.spvm );
411
390
else if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value)
412
- res .spvm = __spirv_JointMatrixSUMadINTEL (A.spvm , B.spvm , C.spvm );
391
+ D .spvm = __spirv_JointMatrixSUMadINTEL (A.spvm , B.spvm , C.spvm );
413
392
else if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value)
414
- res .spvm = __spirv_JointMatrixUSMadINTEL (A.spvm , B.spvm , C.spvm );
393
+ D .spvm = __spirv_JointMatrixUSMadINTEL (A.spvm , B.spvm , C.spvm );
415
394
else
416
- res.spvm = __spirv_JointMatrixMadINTEL (A.spvm , B.spvm , C.spvm );
417
- return res;
395
+ D.spvm = __spirv_JointMatrixMadINTEL (A.spvm , B.spvm , C.spvm );
418
396
#endif // defined(__NVPTX__)
419
397
#else
420
398
std::ignore = A;
421
399
std::ignore = B;
422
400
std::ignore = C;
401
+ std::ignore = D;
402
+ throw runtime_error (" joint matrix is not supported on host device." ,
403
+ PI_ERROR_INVALID_DEVICE);
404
+ #endif // defined(__SYCL_DEVICE_ONLY__)
405
+ }
406
+
407
+ template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
408
+ use Use1, use Use2, layout Layout1, layout Layout2>
409
+ void joint_matrix_copy (
410
+ Group sg, joint_matrix<Group, T1, Use1, Rows, Cols, Layout1> &src,
411
+ joint_matrix<Group, T2, Use2, Rows, Cols, Layout2> &dst) {
412
+ #if defined(__SYCL_DEVICE_ONLY__)
413
+ #if defined(__NVPTX__)
414
+ std::ignore = sg;
415
+ for (int i = 0 ; i < src.cuda_impl .wi_marray .size (); i++) {
416
+ dst.cuda_impl .wi_marray [i] = src.cuda_impl .wi_marray [i];
417
+ }
418
+ #else
419
+ using storage_element_type =
420
+ typename oneapi::detail::jm_type_interpretation_helper_trait<
421
+ T2>::storage_element_type;
422
+ auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data (sg, src);
423
+ auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data (sg, dst);
424
+ for (int i = 0 ; i < wi_data_c.length (); i++) {
425
+ wi_data_dst[i] = static_cast <storage_element_type>(wi_data_c[i]);
426
+ }
427
+ #endif // defined(__NVPTX__)
428
+ #else
429
+ std::ignore = sg;
430
+ std::ignore = dst;
431
+ std::ignore = src;
423
432
throw runtime_error (" joint matrix is not supported on host device." ,
424
433
PI_ERROR_INVALID_DEVICE);
425
434
#endif // defined(__SYCL_DEVICE_ONLY__)
0 commit comments