24
24
#include < sycl/exception.hpp> // for runtime_error
25
25
#include < sycl/ext/oneapi/matrix/matrix-unified-utils.hpp> // for layout, use, tf32, convertMatrixUseEnumToString
26
26
#include < sycl/ext/oneapi/matrix/query-types.hpp> // for convertTypeToMatrixTypeString
27
- #include < sycl/marray.hpp> // for marray
28
- #include < sycl/multi_ptr.hpp> // for multi_ptr
27
+ #include < sycl/marray.hpp> // for marray
28
+ #include < sycl/multi_ptr.hpp> // for multi_ptr
29
29
30
30
#include < cstring> // for size_t, memcpy
31
31
#include < stdint.h> // for uint32_t
@@ -165,34 +165,12 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
165
165
std::ignore = sg;
166
166
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
167
167
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(src);
168
- switch (Layout) {
169
- default :
170
- assert (false && " Invalid Memory Layout!" );
171
- case layout::row_major:
172
- res.spvm = __spirv_JointMatrixLoadINTEL<
173
- DecorT, S, NumRows, NumCols,
174
- spv_matrix_use_traits<use::accumulator>::value,
175
- spv_matrix_layout_traits<layout::dynamic>::value>(
176
- Ptr, stride, __spv::MatrixLayout::RowMajor,
177
- spv_scope_traits<Group>::value);
178
- break ;
179
- case layout::col_major:
180
- res.spvm = __spirv_JointMatrixLoadINTEL<
181
- DecorT, S, NumRows, NumCols,
182
- spv_matrix_use_traits<use::accumulator>::value,
183
- spv_matrix_layout_traits<layout::dynamic>::value>(
184
- Ptr, stride, __spv::MatrixLayout::ColumnMajor,
185
- spv_scope_traits<Group>::value);
186
- break ;
187
- case layout::ext_intel_packed:
188
- res.spvm = __spirv_JointMatrixLoadINTEL<
189
- DecorT, S, NumRows, NumCols,
190
- spv_matrix_use_traits<use::accumulator>::value,
191
- spv_matrix_layout_traits<layout::dynamic>::value>(
192
- Ptr, stride, __spv::MatrixLayout::Packed,
193
- spv_scope_traits<Group>::value);
194
- break ;
195
- }
168
+ res.spvm = __spirv_JointMatrixLoadINTEL<
169
+ DecorT, S, NumRows, NumCols,
170
+ spv_matrix_use_traits<use::accumulator>::value,
171
+ spv_matrix_layout_traits<layout::dynamic>::value>(
172
+ Ptr, stride, sycl::detail::joint_matrix_layout_to_spv (Layout),
173
+ spv_scope_traits<Group>::value);
196
174
#endif // defined(__NVPTX__)
197
175
#else
198
176
std::ignore = sg;
@@ -250,6 +228,83 @@ joint_matrix_load(Group sg,
250
228
#endif // defined(__SYCL_DEVICE_ONLY__)
251
229
}
252
230
231
+ template <typename Group, typename S, typename T, size_t NumRows,
232
+ size_t NumCols, typename PropertyListT,
233
+ std::enable_if_t <std::is_same<S, std::remove_const_t <T>>::value,
234
+ bool > = true >
235
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_load (
236
+ Group sg,
237
+ joint_matrix<Group, S, use::accumulator, NumRows, NumCols,
238
+ sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
239
+ ext::oneapi::experimental::annotated_ptr<T, PropertyListT> src,
240
+ size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) {
241
+ #if defined(__SYCL_DEVICE_ONLY__)
242
+ #if defined(__NVPTX__)
243
+ std::ignore = sg;
244
+ throw runtime_error (" Use joint_matrix_load on multi_ptr on Nvidia device." ,
245
+ PI_ERROR_INVALID_DEVICE);
246
+ #elif defined(__HIP_PLATFORM_AMD_MFMA__)
247
+ throw runtime_error (" Use joint_matrix_load on multi_ptr on AMD device." ,
248
+ PI_ERROR_INVALID_DEVICE);
249
+ #else
250
+ std::ignore = sg;
251
+ T *Ptr = src.get ();
252
+ res.spvm = __spirv_JointMatrixLoadINTEL<
253
+ T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
254
+ spv_matrix_layout_traits<layout::dynamic>::value>(
255
+ Ptr, stride, sycl::detail::joint_matrix_layout_to_spv (Layout),
256
+ spv_scope_traits<Group>::value);
257
+ #endif // defined(__NVPTX__)
258
+ #else
259
+ std::ignore = sg;
260
+ std::ignore = res;
261
+ std::ignore = src;
262
+ std::ignore = stride;
263
+ std::ignore = Layout;
264
+ throw runtime_error (" joint matrix is not supported on host device." ,
265
+ PI_ERROR_INVALID_DEVICE);
266
+ #endif // defined(__SYCL_DEVICE_ONLY__)
267
+ }
268
+
269
+ template <
270
+ typename Group, typename S, typename T, use Use, size_t NumRows,
271
+ size_t NumCols, matrix::layout Layout, typename PropertyListT,
272
+ std::enable_if_t <std::is_same<S, std::remove_const_t <T>>::value ||
273
+ (std::is_same<S, precision::tf32>::value &&
274
+ std::is_same<std::remove_const_t <T>, float >::value),
275
+ bool > = true >
276
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_load (
277
+ Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &res,
278
+ ext::oneapi::experimental::annotated_ptr<T, PropertyListT> src,
279
+ size_t stride) {
280
+ #if defined(__SYCL_DEVICE_ONLY__)
281
+ #if defined(__NVPTX__)
282
+ std::ignore = sg;
283
+ throw runtime_error (" Use joint_matrix_load on multi_ptr on Nvidia device." ,
284
+ PI_ERROR_INVALID_DEVICE);
285
+ #elif defined(__HIP_PLATFORM_AMD_MFMA__)
286
+ throw runtime_error (" Use joint_matrix_load on multi_ptr on AMD device." ,
287
+ PI_ERROR_INVALID_DEVICE);
288
+ #else
289
+ std::ignore = sg;
290
+ T *Ptr = src.get ();
291
+ res.spvm =
292
+ __spirv_JointMatrixLoadINTEL<T, S, NumRows, NumCols,
293
+ spv_matrix_use_traits<Use>::value,
294
+ spv_matrix_layout_traits<Layout>::value>(
295
+ Ptr, stride, spv_matrix_layout_traits<Layout>::value,
296
+ spv_scope_traits<Group>::value);
297
+ #endif // defined(__NVPTX__)
298
+ #else
299
+ std::ignore = sg;
300
+ std::ignore = res;
301
+ std::ignore = src;
302
+ std::ignore = stride;
303
+ throw runtime_error (" joint matrix is not supported on host device." ,
304
+ PI_ERROR_INVALID_DEVICE);
305
+ #endif // defined(__SYCL_DEVICE_ONLY__)
306
+ }
307
+
253
308
template <typename Group, typename T, size_t NumRows, size_t NumCols,
254
309
access::address_space Space, access::decorated IsDecorated>
255
310
inline __SYCL_ALWAYS_INLINE void joint_matrix_store (
@@ -275,34 +330,49 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
275
330
std::ignore = sg;
276
331
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
277
332
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
278
- switch (Layout) {
279
- default :
280
- assert (false && " Invalid Memory Layout!" );
281
- case layout::row_major:
282
- __spirv_JointMatrixStoreINTEL<
283
- DecorT, T, NumRows, NumCols,
284
- spv_matrix_use_traits<use::accumulator>::value,
285
- spv_matrix_layout_traits<layout::dynamic>::value>(
286
- Ptr, src.spvm , stride, __spv::MatrixLayout::RowMajor,
287
- spv_scope_traits<Group>::value);
288
- break ;
289
- case layout::col_major:
290
- __spirv_JointMatrixStoreINTEL<
291
- DecorT, T, NumRows, NumCols,
292
- spv_matrix_use_traits<use::accumulator>::value,
293
- spv_matrix_layout_traits<layout::dynamic>::value>(
294
- Ptr, src.spvm , stride, __spv::MatrixLayout::ColumnMajor,
295
- spv_scope_traits<Group>::value);
296
- break ;
297
- case layout::ext_intel_packed:
298
- __spirv_JointMatrixStoreINTEL<
299
- DecorT, T, NumRows, NumCols,
300
- spv_matrix_use_traits<use::accumulator>::value,
301
- spv_matrix_layout_traits<layout::dynamic>::value>(
302
- Ptr, src.spvm , stride, __spv::MatrixLayout::Packed,
303
- spv_scope_traits<Group>::value);
304
- break ;
305
- }
333
+ __spirv_JointMatrixStoreINTEL<
334
+ DecorT, T, NumRows, NumCols,
335
+ spv_matrix_use_traits<use::accumulator>::value,
336
+ spv_matrix_layout_traits<layout::dynamic>::value>(
337
+ Ptr, src.spvm , stride, sycl::detail::joint_matrix_layout_to_spv (Layout),
338
+ spv_scope_traits<Group>::value);
339
+ #endif // defined(__NVPTX__)
340
+ #else
341
+ std::ignore = sg;
342
+ std::ignore = src;
343
+ std::ignore = dst;
344
+ std::ignore = stride;
345
+ std::ignore = Layout;
346
+ throw runtime_error (" joint matrix is not supported on host device." ,
347
+ PI_ERROR_INVALID_DEVICE);
348
+ #endif // defined(__SYCL_DEVICE_ONLY__)
349
+ }
350
+
351
+ template <typename Group, typename T, size_t NumRows, size_t NumCols,
352
+ typename PropertyListT>
353
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_store (
354
+ Group sg,
355
+ const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
356
+ sycl::ext::oneapi::experimental::matrix::layout::dynamic>
357
+ &src,
358
+ ext::oneapi::experimental::annotated_ptr<T, PropertyListT> dst,
359
+ size_t stride, sycl::ext::oneapi::experimental::matrix::layout Layout) {
360
+ #if defined(__SYCL_DEVICE_ONLY__)
361
+ #if defined(__NVPTX__)
362
+ std::ignore = sg;
363
+ throw runtime_error (" Use joint_matrix_store on multi_ptr on Nvidia device." ,
364
+ PI_ERROR_INVALID_DEVICE);
365
+ #elif defined(__HIP_PLATFORM_AMD_MFMA__)
366
+ throw runtime_error (" Use joint_matrix_store on multi_ptr on AMD device." ,
367
+ PI_ERROR_INVALID_DEVICE);
368
+ #else
369
+ std::ignore = sg;
370
+ T *Ptr = dst.get ();
371
+ __spirv_JointMatrixStoreINTEL<
372
+ T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
373
+ spv_matrix_layout_traits<layout::dynamic>::value>(
374
+ Ptr, src.spvm , stride, sycl::detail::joint_matrix_layout_to_spv (Layout),
375
+ spv_scope_traits<Group>::value);
306
376
#endif // defined(__NVPTX__)
307
377
#else
308
378
std::ignore = sg;
@@ -429,6 +499,46 @@ inline __SYCL_ALWAYS_INLINE float round_to_tf32(const float &a) {
429
499
return ret;
430
500
#endif // defined(__SYCL_DEVICE_ONLY__)
431
501
}
502
+
503
+ template <size_t NumRows, size_t NumCols, typename Group, typename T,
504
+ typename Properties = ext::oneapi::experimental::empty_properties_t >
505
+ inline __SYCL_ALWAYS_INLINE void
506
+ joint_matrix_prefetch (Group sg, T *Ptr, size_t stride,
507
+ sycl::ext::oneapi::experimental::matrix::layout Layout,
508
+ Properties properties = {}) {
509
+ #if defined(__SYCL_DEVICE_ONLY__)
510
+ #if defined(__NVPTX__)
511
+ std::ignore = sg;
512
+ std::ignore = properties;
513
+ throw runtime_error (
514
+ " joint_matrix_prefetch is not supported on Nvidia device." ,
515
+ PI_ERROR_INVALID_DEVICE);
516
+ #elif defined(__HIP_PLATFORM_AMD_MFMA__)
517
+ std::ignore = sg;
518
+ std::ignore = properties;
519
+ throw runtime_error (" joint_matrix_prefetch is not supported on AMD device." ,
520
+ PI_ERROR_INVALID_DEVICE);
521
+ #else
522
+ std::ignore = sg;
523
+ auto prop = properties.template get_property <prefetch_hint_key>();
524
+ // Will be removed once SPIRV implementation also uses offsetpointer
525
+ size_t coordX = 0 ;
526
+ size_t coordY = 0 ;
527
+ __spirv_JointMatrixPrefetchINTEL<T, NumRows, NumCols>(
528
+ Ptr, coordX, coordY, detail::PropertyMetaInfo<decltype (prop)>::value,
529
+ sycl::detail::joint_matrix_layout_to_spv (Layout), stride);
530
+ #endif // defined(__NVPTX__)
531
+ #else
532
+ std::ignore = sg;
533
+ std::ignore = Ptr;
534
+ std::ignore = stride;
535
+ std::ignore = Layout;
536
+ std::ignore = properties;
537
+ throw runtime_error (" joint matrix is not supported on host device." ,
538
+ PI_ERROR_INVALID_DEVICE);
539
+ #endif // defined(__SYCL_DEVICE_ONLY__)
540
+ }
541
+
432
542
} // namespace matrix
433
543
} // namespace experimental
434
544
} // namespace oneapi
0 commit comments