@@ -38,19 +38,22 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 32, int32_t, 8)
38
38
__SYCL_JOINT_MATRIX_OVERLOAD (half, a, 8 , 16 , int32_t , 8 )
39
39
__SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 32 , int32_t , 8 )
40
40
__SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 8 , 32 , float , 8 )
41
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 8 , 32 , int32_t , 4 )
41
42
42
43
__SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 8 , 16 , int32_t , 1 )
43
44
__SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 32 , int32_t , 4 )
44
45
__SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , a, 8 , 16 , int32_t , 1 )
45
46
__SYCL_JOINT_MATRIX_OVERLOAD (uint8_t , b, 16 , 32 , int32_t , 4 )
46
47
__SYCL_JOINT_MATRIX_OVERLOAD (int32_t , accumulator, 8 , 32 , int32_t , 8 )
47
48
49
+
48
50
// m32n8k16
49
51
__SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 32 , 16 , int32_t , 8 )
50
52
__SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 8 , int32_t , 2 )
51
53
__SYCL_JOINT_MATRIX_OVERLOAD (half, a, 32 , 16 , int32_t , 8 )
52
54
__SYCL_JOINT_MATRIX_OVERLOAD (half, b, 16 , 8 , int32_t , 8 )
53
55
__SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 32 , 8 , float , 8 )
56
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 32 , 8 , int32_t , 4 )
54
57
55
58
__SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 32 , 16 , int32_t , 4 )
56
59
__SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 8 , int32_t , 1 )
@@ -64,6 +67,7 @@ __SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 16, int32_t, 8)
64
67
__SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , a, 16 , 16 , int32_t , 4 )
65
68
__SYCL_JOINT_MATRIX_OVERLOAD (uint16_t , b, 16 , 16 , int32_t , 4 )
66
69
__SYCL_JOINT_MATRIX_OVERLOAD (float , accumulator, 16 , 16 , float , 8 )
70
+ __SYCL_JOINT_MATRIX_OVERLOAD (half, accumulator, 16 , 16 , int32_t , 4 )
67
71
68
72
__SYCL_JOINT_MATRIX_OVERLOAD (int8_t , a, 16 , 16 , int32_t , 2 )
69
73
__SYCL_JOINT_MATRIX_OVERLOAD (int8_t , b, 16 , 16 , int32_t , 2 )
@@ -82,7 +86,7 @@ template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
82
86
access::address_space Space, typename Cond = void >
83
87
struct joint_matrix_load_impl {
84
88
void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
85
- T, Use, NumRows, NumCols, Layout> &res,
89
+ T, Use, NumRows, NumCols, Layout, sycl::sub_group > &res,
86
90
multi_ptr<T, Space> src, size_t stride);
87
91
};
88
92
@@ -112,11 +116,8 @@ struct joint_matrix_load_impl<
112
116
Layout == sycl::ext::oneapi::experimental::
113
117
matrix::matrix_layout::col_major>> {
114
118
void load (sycl::ext::oneapi::experimental::matrix::joint_matrix<
115
- T, Use, NumRows, NumCols, Layout> &res,
119
+ T, Use, NumRows, NumCols, Layout, sycl::sub_group > &res,
116
120
multi_ptr<T, Space> src, size_t stride) {
117
- #ifdef __NVPTX__
118
- #ifdef __SYCL_DEVICE_ONLY__
119
-
120
121
if constexpr (std::is_same<T, uint16_t >::value) {
121
122
int32_t *tileptr = reinterpret_cast <int32_t *>(src.get ());
122
123
if constexpr (NumRows == 16 && NumCols == 16 ) {
@@ -203,6 +204,10 @@ struct joint_matrix_load_impl<
203
204
matrix_use::b) {
204
205
__hmma_m16n16k16_ld_b (res.data , tileptr, stride,
205
206
get_layout_id<Layout>());
207
+ } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
208
+ matrix_use::accumulator) {
209
+ __hmma_m16n16k16_ld_c_f16 (res.data , tileptr, stride,
210
+ get_layout_id<Layout>());
206
211
}
207
212
} else if constexpr (NumRows == 8 && NumCols == 16 ) {
208
213
__hmma_m8n32k16_ld_a (res.data , tileptr, stride,
@@ -216,7 +221,14 @@ struct joint_matrix_load_impl<
216
221
} else if constexpr (NumRows == 16 && NumCols == 8 ) {
217
222
__hmma_m32n8k16_ld_b (res.data , tileptr, stride,
218
223
get_layout_id<Layout>());
224
+ } else if constexpr (NumRows == 32 && NumCols == 8 ) {
225
+ __hmma_m32n8k16_ld_c_f16 (res.data , tileptr, stride,
226
+ get_layout_id<Layout>());
227
+ } else if constexpr (NumRows == 8 && NumCols == 32 ) {
228
+ __hmma_m8n32k16_ld_c_f16 (res.data , tileptr, stride,
229
+ get_layout_id<Layout>());
219
230
}
231
+
220
232
} else if constexpr (std::is_same<T, int32_t >::value) {
221
233
if constexpr (NumRows == 16 && NumCols == 16 ) {
222
234
__imma_m16n16k16_ld_c (res.data , src.get (), stride,
@@ -254,8 +266,6 @@ struct joint_matrix_load_impl<
254
266
get_layout_id<Layout>());
255
267
}
256
268
}
257
- #endif
258
- #endif
259
269
}
260
270
};
261
271
@@ -266,7 +276,7 @@ struct joint_matrix_store_impl {
266
276
void
267
277
store (sycl::ext::oneapi::experimental::matrix::joint_matrix<
268
278
T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
269
- NumRows, NumCols, Layout> &src,
279
+ NumRows, NumCols, Layout, sycl::sub_group > &src,
270
280
multi_ptr<T, Space> dst, size_t stride);
271
281
};
272
282
@@ -282,18 +292,19 @@ struct joint_matrix_store_impl<
282
292
void
283
293
store (sycl::ext::oneapi::experimental::matrix::joint_matrix<
284
294
T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
285
- NumRows, NumCols, Layout> &src,
295
+ NumRows, NumCols, Layout, sycl::sub_group > &src,
286
296
multi_ptr<T, Space> dst, size_t stride) {
287
-
288
- #ifdef __NVPTX__
289
- #ifdef __SYCL_DEVICE_ONLY__
290
297
if (NumRows == 16 && NumCols == 16 ) {
291
298
if constexpr (std::is_same<T, float >::value) {
292
299
__hmma_m16n16k16_st_c_f32 (dst.get (), src.data , stride,
293
300
get_layout_id<Layout>());
294
301
} else if constexpr (std::is_same<T, int32_t >::value) {
295
302
__imma_m16n16k16_st_c_i32 (dst.get (), src.data , stride,
296
303
get_layout_id<Layout>());
304
+ } else if constexpr (std::is_same<T, half>::value) {
305
+ int32_t *tileptr = reinterpret_cast <int32_t *>(dst.get ());
306
+ __hmma_m16n16k16_st_c_f16 (tileptr, src.data , stride,
307
+ get_layout_id<Layout>());
297
308
}
298
309
} else if (NumRows == 8 && NumCols == 32 ) {
299
310
if constexpr (std::is_same<T, float >::value) {
@@ -302,6 +313,10 @@ struct joint_matrix_store_impl<
302
313
} else if constexpr (std::is_same<T, int32_t >::value) {
303
314
__imma_m8n32k16_st_c_i32 (dst.get (), src.data , stride,
304
315
get_layout_id<Layout>());
316
+ } else if constexpr (std::is_same<T, half>::value) {
317
+ int32_t *tileptr = reinterpret_cast <int32_t *>(dst.get ());
318
+ __hmma_m8n32k16_st_c_f16 (tileptr, src.data , stride,
319
+ get_layout_id<Layout>());
305
320
}
306
321
} else if (NumRows == 32 && NumCols == 8 ) {
307
322
if constexpr (std::is_same<T, float >::value) {
@@ -310,14 +325,15 @@ struct joint_matrix_store_impl<
310
325
} else if constexpr (std::is_same<T, int32_t >::value) {
311
326
__imma_m32n8k16_st_c_i32 (dst.get (), src.data , stride,
312
327
get_layout_id<Layout>());
328
+ } else if constexpr (std::is_same<T, half>::value) {
329
+ int32_t *tileptr = reinterpret_cast <int32_t *>(dst.get ());
330
+ __hmma_m32n8k16_st_c_f16 (tileptr, src.data , stride,
331
+ get_layout_id<Layout>());
313
332
}
314
333
} else if constexpr (std::is_same<T, double >::value) {
315
334
__dmma_m8n8k4_st_c_f64 (dst.get (), src.data , stride,
316
335
get_layout_id<Layout>());
317
336
}
318
-
319
- #endif
320
- #endif
321
337
}
322
338
};
323
339
@@ -329,18 +345,18 @@ template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
329
345
struct joint_matrix_mad_impl {
330
346
sycl::ext::oneapi::experimental::matrix::joint_matrix<
331
347
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
332
- N, LayoutC>
348
+ N, LayoutC, sycl::sub_group >
333
349
mad (sycl::ext::oneapi::experimental::matrix::joint_matrix<
334
350
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
335
- LayoutA>
351
+ LayoutA, sycl::sub_group >
336
352
A,
337
353
sycl::ext::oneapi::experimental::matrix::joint_matrix<
338
354
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
339
- LayoutB>
355
+ LayoutB, sycl::sub_group >
340
356
B,
341
357
sycl::ext::oneapi::experimental::matrix::joint_matrix<
342
358
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
343
- M, N, LayoutC>
359
+ M, N, LayoutC, sycl::sub_group >
344
360
C);
345
361
};
346
362
@@ -397,26 +413,23 @@ struct joint_matrix_mad_impl<
397
413
col_major)>> {
398
414
sycl::ext::oneapi::experimental::matrix::joint_matrix<
399
415
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
400
- N, LayoutC>
416
+ N, LayoutC, sycl::sub_group >
401
417
mad (sycl::ext::oneapi::experimental::matrix::joint_matrix<
402
418
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
403
- LayoutA>
419
+ LayoutA, sycl::sub_group >
404
420
A,
405
421
sycl::ext::oneapi::experimental::matrix::joint_matrix<
406
422
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
407
- LayoutB>
423
+ LayoutB, sycl::sub_group >
408
424
B,
409
425
sycl::ext::oneapi::experimental::matrix::joint_matrix<
410
426
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
411
- M, N, LayoutC>
427
+ M, N, LayoutC, sycl::sub_group >
412
428
C) {
413
429
sycl::ext::oneapi::experimental::matrix::joint_matrix<
414
430
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
415
- N, LayoutC>
431
+ N, LayoutC, sycl::sub_group >
416
432
D;
417
-
418
- #ifdef __NVPTX__
419
- #ifdef __SYCL_DEVICE_ONLY__
420
433
if constexpr (M == 16 && N == 16 && K == 16 ) {
421
434
if constexpr (std::is_same<T1, int8_t >::value) {
422
435
__imma_m16n16k16_mma_s8 (D.data , A.data , B.data , C.data ,
@@ -425,8 +438,15 @@ struct joint_matrix_mad_impl<
425
438
__imma_m16n16k16_mma_u8 (D.data , A.data , B.data , C.data ,
426
439
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
427
440
} else if constexpr (std::is_same<T1, half>::value) {
428
- __hmma_m16n16k16_mma_f32f32 (D.data , A.data , B.data , C.data ,
429
- get_layout_pair_id<LayoutA, LayoutB>(), 0 );
441
+ if constexpr (std::is_same<T2, float >::value) {
442
+ __hmma_m16n16k16_mma_f32f32 (D.data , A.data , B.data , C.data ,
443
+ get_layout_pair_id<LayoutA, LayoutB>(),
444
+ 0 );
445
+ } else if constexpr (std::is_same<T2, half>::value) {
446
+ __hmma_m16n16k16_mma_f16f16 (D.data , A.data , B.data , C.data ,
447
+ get_layout_pair_id<LayoutA, LayoutB>(),
448
+ 0 );
449
+ }
430
450
} else if constexpr (std::is_same<T1, uint16_t >::value) {
431
451
__mma_bf16_m16n16k16_mma_f32 (D.data , A.data , B.data , C.data ,
432
452
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
@@ -439,8 +459,13 @@ struct joint_matrix_mad_impl<
439
459
__imma_m8n32k16_mma_u8 (D.data , A.data , B.data , C.data ,
440
460
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
441
461
} else if constexpr (std::is_same<T1, half>::value) {
442
- __hmma_m8n32k16_mma_f32f32 (D.data , A.data , B.data , C.data ,
443
- get_layout_pair_id<LayoutA, LayoutB>(), 0 );
462
+ if constexpr (std::is_same<T2, float >::value) {
463
+ __hmma_m8n32k16_mma_f32f32 (D.data , A.data , B.data , C.data ,
464
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
465
+ } else if constexpr (std::is_same<T2, half>::value) {
466
+ __hmma_m8n32k16_mma_f16f16 (D.data , A.data , B.data , C.data ,
467
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
468
+ }
444
469
} else if constexpr (std::is_same<T1, uint16_t >::value) {
445
470
__mma_bf16_m8n32k16_mma_f32 (D.data , A.data , B.data , C.data ,
446
471
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
@@ -456,16 +481,18 @@ struct joint_matrix_mad_impl<
456
481
__mma_bf16_m32n8k16_mma_f32 (D.data , A.data , B.data , C.data ,
457
482
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
458
483
} else if constexpr (std::is_same<T1, half>::value) {
459
- __hmma_m32n8k16_mma_f32f32 (D.data , A.data , B.data , C.data ,
460
- get_layout_pair_id<LayoutA, LayoutB>(), 0 );
484
+ if constexpr (std::is_same<T2, float >::value) {
485
+ __hmma_m32n8k16_mma_f32f32 (D.data , A.data , B.data , C.data ,
486
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
487
+ } else if constexpr (std::is_same<T2, half>::value) {
488
+ __hmma_m32n8k16_mma_f16f16 (D.data , A.data , B.data , C.data ,
489
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
490
+ }
461
491
}
462
492
} else if constexpr (std::is_same<T1, double >::value) {
463
493
__dmma_m8n8k4_mma_f64 (D.data , A.data , B.data , C.data ,
464
494
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
465
495
}
466
- #endif
467
- #endif
468
-
469
496
return D;
470
497
}
471
498
};
@@ -479,9 +506,20 @@ template <typename Group, typename T, matrix_use Use, size_t NumRows,
479
506
void joint_matrix_load (
480
507
Group sg, joint_matrix<T, Use, NumRows, NumCols, Layout, Group> &res,
481
508
multi_ptr<T, Space> src, size_t stride) {
482
- sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
509
+ #ifdef __SYCL_DEVICE_ONLY__
510
+ #ifdef __NVPTX__
511
+ sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
483
512
Layout, Space>{}
484
513
.load (res, src, stride);
514
+ #endif
515
+ #else
516
+ (void )sg;
517
+ (void )res;
518
+ (void )src;
519
+ (void )stride;
520
+ throw runtime_error (" joint_matrix_load is not supported on host device." ,
521
+ PI_INVALID_DEVICE);
522
+ #endif // __SYCL_DEVICE_ONLY__*/
485
523
}
486
524
487
525
template <typename Group, typename T, size_t NumRows, size_t NumCols,
@@ -490,9 +528,20 @@ void joint_matrix_store(Group sg,
490
528
joint_matrix<T, matrix_use::accumulator, NumRows,
491
529
NumCols, Layout, Group> &src,
492
530
multi_ptr<T, Space> dst, size_t stride) {
493
- sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
531
+ #ifdef __SYCL_DEVICE_ONLY__
532
+ #ifdef __NVPTX__
533
+ sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
494
534
Layout, Space>{}
495
535
.store (src, dst, stride);
536
+ #endif
537
+ #else
538
+ (void )sg;
539
+ (void )src;
540
+ (void )dst;
541
+ (void )stride;
542
+ throw runtime_error (" joint_matrix_store is not supported on host device." ,
543
+ PI_INVALID_DEVICE);
544
+ #endif // __SYCL_DEVICE_ONLY__*/
496
545
}
497
546
498
547
template <typename Group, typename T1, typename T2, std::size_t M,
@@ -503,9 +552,20 @@ joint_matrix_mad(
503
552
Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
504
553
joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
505
554
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
555
+ #ifdef __SYCL_DEVICE_ONLY__
556
+ #ifdef __NVPTX__
506
557
return sycl::ext::oneapi::detail::joint_matrix_mad_impl<
507
558
T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{}
508
559
.mad (A, B, C);
560
+ #endif
561
+ #else
562
+ (void )sg;
563
+ (void )A;
564
+ (void )B;
565
+ (void )C;
566
+ throw runtime_error (" joint_matrix_mad is not supported on host device." ,
567
+ PI_INVALID_DEVICE);
568
+ #endif // __SYCL_DEVICE_ONLY__*/
509
569
}
510
570
511
571
} // namespace experimental::matrix
0 commit comments