Skip to content

Commit 4d756a4

Browse files
committed
Added T1=T2=half cases.
Constrained the joint_matrix_XX_impl functions to take only a joint_matrix whose Group template parameter is sycl::sub_group. Signed-off-by: jack.kirk <[email protected]>
1 parent 7eb02c2 commit 4d756a4

File tree

3 files changed

+289
-38
lines changed

3 files changed

+289
-38
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp

Lines changed: 98 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -38,19 +38,22 @@ __SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 32, int32_t, 8)
3838
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 8, 16, int32_t, 8)
3939
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 32, int32_t, 8)
4040
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 8, 32, float, 8)
41+
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 8, 32, int32_t, 4)
4142

4243
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 8, 16, int32_t, 1)
4344
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 32, int32_t, 4)
4445
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 8, 16, int32_t, 1)
4546
__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 32, int32_t, 4)
4647
__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 8, 32, int32_t, 8)
4748

49+
4850
// m32n8k16
4951
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 32, 16, int32_t, 8)
5052
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 8, int32_t, 2)
5153
__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 32, 16, int32_t, 8)
5254
__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 8, int32_t, 8)
5355
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 32, 8, float, 8)
56+
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 32, 8, int32_t, 4)
5457

5558
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 32, 16, int32_t, 4)
5659
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 8, int32_t, 1)
@@ -64,6 +67,7 @@ __SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 16, int32_t, 8)
6467
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 16, 16, int32_t, 4)
6568
__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 16, int32_t, 4)
6669
__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 16, 16, float, 8)
70+
__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 16, 16, int32_t, 4)
6771

6872
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 16, 16, int32_t, 2)
6973
__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 16, int32_t, 2)
@@ -82,7 +86,7 @@ template <typename T, sycl::ext::oneapi::experimental::matrix::matrix_use Use,
8286
access::address_space Space, typename Cond = void>
8387
struct joint_matrix_load_impl {
8488
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
85-
T, Use, NumRows, NumCols, Layout> &res,
89+
T, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
8690
multi_ptr<T, Space> src, size_t stride);
8791
};
8892

@@ -112,11 +116,8 @@ struct joint_matrix_load_impl<
112116
Layout == sycl::ext::oneapi::experimental::
113117
matrix::matrix_layout::col_major>> {
114118
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
115-
T, Use, NumRows, NumCols, Layout> &res,
119+
T, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
116120
multi_ptr<T, Space> src, size_t stride) {
117-
#ifdef __NVPTX__
118-
#ifdef __SYCL_DEVICE_ONLY__
119-
120121
if constexpr (std::is_same<T, uint16_t>::value) {
121122
int32_t *tileptr = reinterpret_cast<int32_t *>(src.get());
122123
if constexpr (NumRows == 16 && NumCols == 16) {
@@ -203,6 +204,10 @@ struct joint_matrix_load_impl<
203204
matrix_use::b) {
204205
__hmma_m16n16k16_ld_b(res.data, tileptr, stride,
205206
get_layout_id<Layout>());
207+
} else if constexpr (Use == sycl::ext::oneapi::experimental::matrix::
208+
matrix_use::accumulator) {
209+
__hmma_m16n16k16_ld_c_f16(res.data, tileptr, stride,
210+
get_layout_id<Layout>());
206211
}
207212
} else if constexpr (NumRows == 8 && NumCols == 16) {
208213
__hmma_m8n32k16_ld_a(res.data, tileptr, stride,
@@ -216,7 +221,14 @@ struct joint_matrix_load_impl<
216221
} else if constexpr (NumRows == 16 && NumCols == 8) {
217222
__hmma_m32n8k16_ld_b(res.data, tileptr, stride,
218223
get_layout_id<Layout>());
224+
} else if constexpr (NumRows == 32 && NumCols == 8) {
225+
__hmma_m32n8k16_ld_c_f16(res.data, tileptr, stride,
226+
get_layout_id<Layout>());
227+
} else if constexpr (NumRows == 8 && NumCols == 32) {
228+
__hmma_m8n32k16_ld_c_f16(res.data, tileptr, stride,
229+
get_layout_id<Layout>());
219230
}
231+
220232
} else if constexpr (std::is_same<T, int32_t>::value) {
221233
if constexpr (NumRows == 16 && NumCols == 16) {
222234
__imma_m16n16k16_ld_c(res.data, src.get(), stride,
@@ -254,8 +266,6 @@ struct joint_matrix_load_impl<
254266
get_layout_id<Layout>());
255267
}
256268
}
257-
#endif
258-
#endif
259269
}
260270
};
261271

@@ -266,7 +276,7 @@ struct joint_matrix_store_impl {
266276
void
267277
store(sycl::ext::oneapi::experimental::matrix::joint_matrix<
268278
T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
269-
NumRows, NumCols, Layout> &src,
279+
NumRows, NumCols, Layout, sycl::sub_group> &src,
270280
multi_ptr<T, Space> dst, size_t stride);
271281
};
272282

@@ -282,18 +292,19 @@ struct joint_matrix_store_impl<
282292
void
283293
store(sycl::ext::oneapi::experimental::matrix::joint_matrix<
284294
T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
285-
NumRows, NumCols, Layout> &src,
295+
NumRows, NumCols, Layout, sycl::sub_group> &src,
286296
multi_ptr<T, Space> dst, size_t stride) {
287-
288-
#ifdef __NVPTX__
289-
#ifdef __SYCL_DEVICE_ONLY__
290297
if (NumRows == 16 && NumCols == 16) {
291298
if constexpr (std::is_same<T, float>::value) {
292299
__hmma_m16n16k16_st_c_f32(dst.get(), src.data, stride,
293300
get_layout_id<Layout>());
294301
} else if constexpr (std::is_same<T, int32_t>::value) {
295302
__imma_m16n16k16_st_c_i32(dst.get(), src.data, stride,
296303
get_layout_id<Layout>());
304+
} else if constexpr (std::is_same<T, half>::value) {
305+
int32_t *tileptr = reinterpret_cast<int32_t *>(dst.get());
306+
__hmma_m16n16k16_st_c_f16(tileptr, src.data, stride,
307+
get_layout_id<Layout>());
297308
}
298309
} else if (NumRows == 8 && NumCols == 32) {
299310
if constexpr (std::is_same<T, float>::value) {
@@ -302,6 +313,10 @@ struct joint_matrix_store_impl<
302313
} else if constexpr (std::is_same<T, int32_t>::value) {
303314
__imma_m8n32k16_st_c_i32(dst.get(), src.data, stride,
304315
get_layout_id<Layout>());
316+
} else if constexpr (std::is_same<T, half>::value) {
317+
int32_t *tileptr = reinterpret_cast<int32_t *>(dst.get());
318+
__hmma_m8n32k16_st_c_f16(tileptr, src.data, stride,
319+
get_layout_id<Layout>());
305320
}
306321
} else if (NumRows == 32 && NumCols == 8) {
307322
if constexpr (std::is_same<T, float>::value) {
@@ -310,14 +325,15 @@ struct joint_matrix_store_impl<
310325
} else if constexpr (std::is_same<T, int32_t>::value) {
311326
__imma_m32n8k16_st_c_i32(dst.get(), src.data, stride,
312327
get_layout_id<Layout>());
328+
} else if constexpr (std::is_same<T, half>::value) {
329+
int32_t *tileptr = reinterpret_cast<int32_t *>(dst.get());
330+
__hmma_m32n8k16_st_c_f16(tileptr, src.data, stride,
331+
get_layout_id<Layout>());
313332
}
314333
} else if constexpr (std::is_same<T, double>::value) {
315334
__dmma_m8n8k4_st_c_f64(dst.get(), src.data, stride,
316335
get_layout_id<Layout>());
317336
}
318-
319-
#endif
320-
#endif
321337
}
322338
};
323339

@@ -329,18 +345,18 @@ template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
329345
struct joint_matrix_mad_impl {
330346
sycl::ext::oneapi::experimental::matrix::joint_matrix<
331347
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
332-
N, LayoutC>
348+
N, LayoutC, sycl::sub_group>
333349
mad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
334350
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
335-
LayoutA>
351+
LayoutA, sycl::sub_group>
336352
A,
337353
sycl::ext::oneapi::experimental::matrix::joint_matrix<
338354
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
339-
LayoutB>
355+
LayoutB, sycl::sub_group>
340356
B,
341357
sycl::ext::oneapi::experimental::matrix::joint_matrix<
342358
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
343-
M, N, LayoutC>
359+
M, N, LayoutC, sycl::sub_group>
344360
C);
345361
};
346362

@@ -397,26 +413,23 @@ struct joint_matrix_mad_impl<
397413
col_major)>> {
398414
sycl::ext::oneapi::experimental::matrix::joint_matrix<
399415
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
400-
N, LayoutC>
416+
N, LayoutC, sycl::sub_group>
401417
mad(sycl::ext::oneapi::experimental::matrix::joint_matrix<
402418
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K,
403-
LayoutA>
419+
LayoutA, sycl::sub_group>
404420
A,
405421
sycl::ext::oneapi::experimental::matrix::joint_matrix<
406422
T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N,
407-
LayoutB>
423+
LayoutB, sycl::sub_group>
408424
B,
409425
sycl::ext::oneapi::experimental::matrix::joint_matrix<
410426
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator,
411-
M, N, LayoutC>
427+
M, N, LayoutC, sycl::sub_group>
412428
C) {
413429
sycl::ext::oneapi::experimental::matrix::joint_matrix<
414430
T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M,
415-
N, LayoutC>
431+
N, LayoutC, sycl::sub_group>
416432
D;
417-
418-
#ifdef __NVPTX__
419-
#ifdef __SYCL_DEVICE_ONLY__
420433
if constexpr (M == 16 && N == 16 && K == 16) {
421434
if constexpr (std::is_same<T1, int8_t>::value) {
422435
__imma_m16n16k16_mma_s8(D.data, A.data, B.data, C.data,
@@ -425,8 +438,15 @@ struct joint_matrix_mad_impl<
425438
__imma_m16n16k16_mma_u8(D.data, A.data, B.data, C.data,
426439
get_layout_pair_id<LayoutA, LayoutB>(), 0);
427440
} else if constexpr (std::is_same<T1, half>::value) {
428-
__hmma_m16n16k16_mma_f32f32(D.data, A.data, B.data, C.data,
429-
get_layout_pair_id<LayoutA, LayoutB>(), 0);
441+
if constexpr (std::is_same<T2, float>::value) {
442+
__hmma_m16n16k16_mma_f32f32(D.data, A.data, B.data, C.data,
443+
get_layout_pair_id<LayoutA, LayoutB>(),
444+
0);
445+
} else if constexpr (std::is_same<T2, half>::value) {
446+
__hmma_m16n16k16_mma_f16f16(D.data, A.data, B.data, C.data,
447+
get_layout_pair_id<LayoutA, LayoutB>(),
448+
0);
449+
}
430450
} else if constexpr (std::is_same<T1, uint16_t>::value) {
431451
__mma_bf16_m16n16k16_mma_f32(D.data, A.data, B.data, C.data,
432452
get_layout_pair_id<LayoutA, LayoutB>(), 0);
@@ -439,8 +459,13 @@ struct joint_matrix_mad_impl<
439459
__imma_m8n32k16_mma_u8(D.data, A.data, B.data, C.data,
440460
get_layout_pair_id<LayoutA, LayoutB>(), 0);
441461
} else if constexpr (std::is_same<T1, half>::value) {
442-
__hmma_m8n32k16_mma_f32f32(D.data, A.data, B.data, C.data,
443-
get_layout_pair_id<LayoutA, LayoutB>(), 0);
462+
if constexpr (std::is_same<T2, float>::value) {
463+
__hmma_m8n32k16_mma_f32f32(D.data, A.data, B.data, C.data,
464+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
465+
} else if constexpr (std::is_same<T2, half>::value) {
466+
__hmma_m8n32k16_mma_f16f16(D.data, A.data, B.data, C.data,
467+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
468+
}
444469
} else if constexpr (std::is_same<T1, uint16_t>::value) {
445470
__mma_bf16_m8n32k16_mma_f32(D.data, A.data, B.data, C.data,
446471
get_layout_pair_id<LayoutA, LayoutB>(), 0);
@@ -456,16 +481,18 @@ struct joint_matrix_mad_impl<
456481
__mma_bf16_m32n8k16_mma_f32(D.data, A.data, B.data, C.data,
457482
get_layout_pair_id<LayoutA, LayoutB>(), 0);
458483
} else if constexpr (std::is_same<T1, half>::value) {
459-
__hmma_m32n8k16_mma_f32f32(D.data, A.data, B.data, C.data,
460-
get_layout_pair_id<LayoutA, LayoutB>(), 0);
484+
if constexpr (std::is_same<T2, float>::value) {
485+
__hmma_m32n8k16_mma_f32f32(D.data, A.data, B.data, C.data,
486+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
487+
} else if constexpr (std::is_same<T2, half>::value) {
488+
__hmma_m32n8k16_mma_f16f16(D.data, A.data, B.data, C.data,
489+
get_layout_pair_id<LayoutA, LayoutB>(), 0);
490+
}
461491
}
462492
} else if constexpr (std::is_same<T1, double>::value) {
463493
__dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data,
464494
get_layout_pair_id<LayoutA, LayoutB>(), 0);
465495
}
466-
#endif
467-
#endif
468-
469496
return D;
470497
}
471498
};
@@ -479,9 +506,20 @@ template <typename Group, typename T, matrix_use Use, size_t NumRows,
479506
void joint_matrix_load(
480507
Group sg, joint_matrix<T, Use, NumRows, NumCols, Layout, Group> &res,
481508
multi_ptr<T, Space> src, size_t stride) {
482-
sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
509+
#ifdef __SYCL_DEVICE_ONLY__
510+
#ifdef __NVPTX__
511+
sycl::ext::oneapi::detail::joint_matrix_load_impl<T, Use, NumRows, NumCols,
483512
Layout, Space>{}
484513
.load(res, src, stride);
514+
#endif
515+
#else
516+
(void)sg;
517+
(void)res;
518+
(void)src;
519+
(void)stride;
520+
throw runtime_error("joint_matrix_load is not supported on host device.",
521+
PI_INVALID_DEVICE);
522+
#endif // __SYCL_DEVICE_ONLY__
485523
}
486524

487525
template <typename Group, typename T, size_t NumRows, size_t NumCols,
@@ -490,9 +528,20 @@ void joint_matrix_store(Group sg,
490528
joint_matrix<T, matrix_use::accumulator, NumRows,
491529
NumCols, Layout, Group> &src,
492530
multi_ptr<T, Space> dst, size_t stride) {
493-
sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
531+
#ifdef __SYCL_DEVICE_ONLY__
532+
#ifdef __NVPTX__
533+
sycl::ext::oneapi::detail::joint_matrix_store_impl<T, NumRows, NumCols,
494534
Layout, Space>{}
495535
.store(src, dst, stride);
536+
#endif
537+
#else
538+
(void)sg;
539+
(void)src;
540+
(void)dst;
541+
(void)stride;
542+
throw runtime_error("joint_matrix_store is not supported on host device.",
543+
PI_INVALID_DEVICE);
544+
#endif // __SYCL_DEVICE_ONLY__
496545
}
497546

498547
template <typename Group, typename T1, typename T2, std::size_t M,
@@ -503,9 +552,20 @@ joint_matrix_mad(
503552
Group sg, joint_matrix<T1, matrix_use::a, M, K, LayoutA, Group> A,
504553
joint_matrix<T1, matrix_use::b, K, N, LayoutB, Group> B,
505554
joint_matrix<T2, matrix_use::accumulator, M, N, LayoutC, Group> C) {
555+
#ifdef __SYCL_DEVICE_ONLY__
556+
#ifdef __NVPTX__
506557
return sycl::ext::oneapi::detail::joint_matrix_mad_impl<
507558
T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{}
508559
.mad(A, B, C);
560+
#endif
561+
#else
562+
(void)sg;
563+
(void)A;
564+
(void)B;
565+
(void)C;
566+
throw runtime_error("joint_matrix_mad is not supported on host device.",
567+
PI_INVALID_DEVICE);
568+
#endif // __SYCL_DEVICE_ONLY__
509569
}
510570

511571
} // namespace experimental::matrix

0 commit comments

Comments
 (0)