@@ -357,63 +357,59 @@ template <sycl::ext::oneapi::experimental::matrix::layout Layout, typename T,
357
357
size_t NumRows, size_t NumCols, access::address_space Space,
358
358
access::decorated IsDecorated>
359
359
void store_layoutT (
360
- joint_matrix_cuda<
360
+ const joint_matrix_cuda<
361
361
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
362
362
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
363
363
multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
364
364
if constexpr (NumRows == 16 && NumCols == 16 ) {
365
365
if constexpr (std::is_same_v<T, float >) {
366
- __hmma_m16n16k16_st_c_f32 (dst.get (),
367
- reinterpret_cast <float *>(&src.wi_marray ),
368
- stride, get_layout_id<Layout>());
366
+ __hmma_m16n16k16_st_c_f32 (dst.get (), &src.wi_marray [0 ], stride,
367
+ get_layout_id<Layout>());
369
368
} else if constexpr (std::is_same_v<T, int32_t >) {
370
- __imma_m16n16k16_st_c_i32 (dst.get (),
371
- reinterpret_cast <int32_t *>(&src.wi_marray ),
372
- stride, get_layout_id<Layout>());
369
+ __imma_m16n16k16_st_c_i32 (dst.get (), &src.wi_marray [0 ], stride,
370
+ get_layout_id<Layout>());
373
371
} else if constexpr (std::is_same_v<T, half>) {
374
- __hmma_m16n16k16_st_c_f16 (reinterpret_cast <int32_t *>(dst.get ()),
375
- reinterpret_cast <int32_t *>(&src.wi_marray ),
376
- stride, get_layout_id<Layout>());
372
+ __hmma_m16n16k16_st_c_f16 (
373
+ reinterpret_cast <int32_t *>(dst.get ()),
374
+ reinterpret_cast <const int32_t *>(&src.wi_marray [0 ]), stride,
375
+ get_layout_id<Layout>());
377
376
}
378
377
} else if constexpr (NumRows == 8 && NumCols == 32 ) {
379
378
if constexpr (std::is_same_v<T, float >) {
380
- __hmma_m8n32k16_st_c_f32 (dst.get (),
381
- reinterpret_cast <float *>(&src.wi_marray ),
382
- stride, get_layout_id<Layout>());
379
+ __hmma_m8n32k16_st_c_f32 (dst.get (), &src.wi_marray [0 ], stride,
380
+ get_layout_id<Layout>());
383
381
} else if constexpr (std::is_same_v<T, int32_t >) {
384
- __imma_m8n32k16_st_c_i32 (dst.get (),
385
- reinterpret_cast <int32_t *>(&src.wi_marray ),
386
- stride, get_layout_id<Layout>());
382
+ __imma_m8n32k16_st_c_i32 (dst.get (), &src.wi_marray [0 ], stride,
383
+ get_layout_id<Layout>());
387
384
} else if constexpr (std::is_same_v<T, half>) {
388
- __hmma_m8n32k16_st_c_f16 (reinterpret_cast <int32_t *>(dst.get ()),
389
- reinterpret_cast <int32_t *>(&src.wi_marray ),
390
- stride, get_layout_id<Layout>());
385
+ __hmma_m8n32k16_st_c_f16 (
386
+ reinterpret_cast <int32_t *>(dst.get ()),
387
+ reinterpret_cast <const int32_t *>(&src.wi_marray [0 ]), stride,
388
+ get_layout_id<Layout>());
391
389
}
392
390
} else if constexpr (NumRows == 32 && NumCols == 8 ) {
393
391
if constexpr (std::is_same_v<T, float >) {
394
- __hmma_m32n8k16_st_c_f32 (dst.get (),
395
- reinterpret_cast <float *>(&src.wi_marray ),
396
- stride, get_layout_id<Layout>());
392
+ __hmma_m32n8k16_st_c_f32 (dst.get (), &src.wi_marray [0 ], stride,
393
+ get_layout_id<Layout>());
397
394
} else if constexpr (std::is_same_v<T, int32_t >) {
398
- __imma_m32n8k16_st_c_i32 (dst.get (),
399
- reinterpret_cast <int32_t *>(&src.wi_marray ),
400
- stride, get_layout_id<Layout>());
395
+ __imma_m32n8k16_st_c_i32 (dst.get (), &src.wi_marray [0 ], stride,
396
+ get_layout_id<Layout>());
401
397
} else if constexpr (std::is_same_v<T, half>) {
402
- __hmma_m32n8k16_st_c_f16 (reinterpret_cast <int32_t *>(dst.get ()),
403
- reinterpret_cast <int32_t *>(&src.wi_marray ),
404
- stride, get_layout_id<Layout>());
398
+ __hmma_m32n8k16_st_c_f16 (
399
+ reinterpret_cast <int32_t *>(dst.get ()),
400
+ reinterpret_cast <const int32_t *>(&src.wi_marray [0 ]), stride,
401
+ get_layout_id<Layout>());
405
402
}
406
403
} else if constexpr (std::is_same_v<T, double >) {
407
- __dmma_m8n8k4_st_c_f64 (dst.get (),
408
- reinterpret_cast <double *>(&src.wi_marray ), stride,
404
+ __dmma_m8n8k4_st_c_f64 (dst.get (), &src.wi_marray [0 ], stride,
409
405
get_layout_id<Layout>());
410
406
}
411
407
}
412
408
413
409
template <typename T, size_t NumRows, size_t NumCols,
414
410
access::address_space Space, access::decorated IsDecorated>
415
411
void joint_matrix_store_cuda (
416
- joint_matrix_cuda<
412
+ const joint_matrix_cuda<
417
413
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
418
414
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
419
415
multi_ptr<T, Space, IsDecorated> dst, size_t stride,
@@ -465,8 +461,8 @@ constexpr int get_layout_pair_id<
465
461
}
466
462
467
463
template <
468
- typename Tm, typename Tc, std:: size_t M , std::size_t K , std::size_t N ,
469
- sycl::ext::oneapi::experimental::matrix::layout LayoutA,
464
+ typename Tm, typename Tc, typename Td , std::size_t M , std::size_t K ,
465
+ std:: size_t N, sycl::ext::oneapi::experimental::matrix::layout LayoutA,
470
466
sycl::ext::oneapi::experimental::matrix::layout LayoutB,
471
467
std::enable_if_t <
472
468
(LayoutA ==
@@ -480,13 +476,13 @@ template <
480
476
bool > = true >
481
477
void joint_matrix_mad_cuda (
482
478
joint_matrix_cuda<
483
- Tc , sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
479
+ Td , sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
484
480
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
485
- joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::a, M, K ,
486
- LayoutA> &A,
487
- joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::b, K, N ,
488
- LayoutB> &B,
489
- joint_matrix_cuda<
481
+ const joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::a,
482
+ M, K, LayoutA> &A,
483
+ const joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::b,
484
+ K, N, LayoutB> &B,
485
+ const joint_matrix_cuda<
490
486
Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
491
487
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) {
492
488
if constexpr (M == 16 && N == 16 && K == 16 ) {
@@ -506,16 +502,29 @@ void joint_matrix_mad_cuda(
506
502
auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
507
503
auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
508
504
if constexpr (std::is_same_v<Tc, float >) {
509
- __hmma_m16n16k16_mma_f32f32 (
510
- reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
511
- reinterpret_cast <const float *>(&C.wi_marray ),
512
- get_layout_pair_id<LayoutA, LayoutB>(), 0 );
513
-
505
+ if constexpr (std::is_same<Td, float >::value) {
506
+ __hmma_m16n16k16_mma_f32f32 (
507
+ reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
508
+ reinterpret_cast <const float *>(&C.wi_marray ),
509
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
510
+ } else {
511
+ __hmma_m16n16k16_mma_f16f32 (
512
+ reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
513
+ reinterpret_cast <const float *>(&C.wi_marray ),
514
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
515
+ }
514
516
} else if constexpr (std::is_same_v<Tc, half>) {
515
- __hmma_m16n16k16_mma_f16f16 (
516
- reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
517
- reinterpret_cast <const int32_t *>(&C.wi_marray ),
518
- get_layout_pair_id<LayoutA, LayoutB>(), 0 );
517
+ if constexpr (std::is_same<Td, float >::value) {
518
+ __hmma_m16n16k16_mma_f32f16 (
519
+ reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
520
+ reinterpret_cast <const int32_t *>(&C.wi_marray ),
521
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
522
+ } else {
523
+ __hmma_m16n16k16_mma_f16f16 (
524
+ reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
525
+ reinterpret_cast <const int32_t *>(&C.wi_marray ),
526
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
527
+ }
519
528
}
520
529
} else if constexpr (std::is_same_v<Tm, sycl::ext::oneapi::bfloat16>) {
521
530
__mma_bf16_m16n16k16_mma_f32 (
@@ -542,15 +551,29 @@ void joint_matrix_mad_cuda(
542
551
auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
543
552
auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
544
553
if constexpr (std::is_same_v<Tc, float >) {
545
- __hmma_m8n32k16_mma_f32f32 (
546
- reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
547
- reinterpret_cast <const float *>(&C.wi_marray ),
548
- get_layout_pair_id<LayoutA, LayoutB>(), 0 );
554
+ if constexpr (std::is_same<Td, float >::value) {
555
+ __hmma_m8n32k16_mma_f32f32 (
556
+ reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
557
+ reinterpret_cast <const float *>(&C.wi_marray ),
558
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
559
+ } else {
560
+ __hmma_m8n32k16_mma_f16f32 (
561
+ reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
562
+ reinterpret_cast <const float *>(&C.wi_marray ),
563
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
564
+ }
549
565
} else if constexpr (std::is_same_v<Tc, half>) {
550
- __hmma_m8n32k16_mma_f16f16 (
551
- reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
552
- reinterpret_cast <const int32_t *>(&C.wi_marray ),
553
- get_layout_pair_id<LayoutA, LayoutB>(), 0 );
566
+ if constexpr (std::is_same<Td, float >::value) {
567
+ __hmma_m8n32k16_mma_f32f16 (
568
+ reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
569
+ reinterpret_cast <const int32_t *>(&C.wi_marray ),
570
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
571
+ } else {
572
+ __hmma_m8n32k16_mma_f16f16 (
573
+ reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
574
+ reinterpret_cast <const int32_t *>(&C.wi_marray ),
575
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
576
+ }
554
577
}
555
578
} else if constexpr (std::is_same_v<Tm, sycl::ext::oneapi::bfloat16>) {
556
579
__mma_bf16_m8n32k16_mma_f32 (
@@ -581,25 +604,40 @@ void joint_matrix_mad_cuda(
581
604
reinterpret_cast <const float *>(&C.wi_marray ),
582
605
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
583
606
} else if constexpr (std::is_same_v<Tm, half>) {
607
+
584
608
auto ptrA = reinterpret_cast <const int32_t *>(&A.wi_marray );
585
609
auto ptrB = reinterpret_cast <const int32_t *>(&B.wi_marray );
586
610
if constexpr (std::is_same_v<Tc, float >) {
587
- __hmma_m32n8k16_mma_f32f32 (
588
- reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
589
- reinterpret_cast <const float *>(&C.wi_marray ),
590
- get_layout_pair_id<LayoutA, LayoutB>(), 0 );
611
+ if constexpr (std::is_same<Td, float >::value) {
612
+ __hmma_m32n8k16_mma_f32f32 (
613
+ reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
614
+ reinterpret_cast <const float *>(&C.wi_marray ),
615
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
616
+ } else {
617
+ __hmma_m32n8k16_mma_f16f32 (
618
+ reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
619
+ reinterpret_cast <const float *>(&C.wi_marray ),
620
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
621
+ }
591
622
} else if constexpr (std::is_same_v<Tc, half>) {
592
- __hmma_m32n8k16_mma_f16f16 (
593
- reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
594
- reinterpret_cast <const int32_t *>(&C.wi_marray ),
595
- get_layout_pair_id<LayoutA, LayoutB>(), 0 );
623
+ if constexpr (std::is_same<Td, float >::value) {
624
+ __hmma_m32n8k16_mma_f32f16 (
625
+ reinterpret_cast <float *>(&D.wi_marray ), ptrA, ptrB,
626
+ reinterpret_cast <const int32_t *>(&C.wi_marray ),
627
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
628
+ } else {
629
+ __hmma_m32n8k16_mma_f16f16 (
630
+ reinterpret_cast <int32_t *>(&D.wi_marray ), ptrA, ptrB,
631
+ reinterpret_cast <const int32_t *>(&C.wi_marray ),
632
+ get_layout_pair_id<LayoutA, LayoutB>(), 0 );
633
+ }
596
634
}
597
635
}
598
636
} else if constexpr (M == 16 && N == 16 && K == 8 ) {
599
637
__mma_tf32_m16n16k8_mma_f32 (reinterpret_cast <float *>(&D.wi_marray ),
600
- reinterpret_cast <int32_t *>(&A.wi_marray ),
601
- reinterpret_cast <int32_t *>(&B.wi_marray ),
602
- reinterpret_cast <float *>(&C.wi_marray ),
638
+ reinterpret_cast <const int32_t *>(&A.wi_marray ),
639
+ reinterpret_cast <const int32_t *>(&B.wi_marray ),
640
+ reinterpret_cast <const float *>(&C.wi_marray ),
603
641
get_layout_pair_id<LayoutA, LayoutB>(), 0 );
604
642
} else if constexpr (std::is_same_v<Tm, double >) {
605
643
__dmma_m8n8k4_mma_f64 (reinterpret_cast <double *>(&D.wi_marray ),
0 commit comments