@@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
82
82
return device_type.str ();
83
83
}
84
84
85
+ template <typename Ts> struct matrix_info_t {
86
+ oneapi::mkl::transpose transpose_info[2 ];
87
+ Ts value_info[2 ];
88
+ std::int64_t size_info[3 ];
89
+ std::int64_t ld_info[3 ];
90
+ std::int64_t groupsize_info;
91
+ };
92
+
85
93
namespace dpct
86
94
{
87
95
typedef sycl::queue *queue_ptr;
@@ -1737,26 +1745,13 @@ namespace dpct
1737
1745
};
1738
1746
1739
1747
template <class Ta , class Tb , class Tc , class Ts >
1740
- inline void gemm_batch_impl (sycl::queue &q, oneapi::mkl::transpose a_trans,
1741
- oneapi::mkl::transpose b_trans, int m, int n, int k,
1742
- const void *alpha, const void **a, int lda,
1743
- const void **b, int ldb, const void *beta, void **c,
1744
- int ldc, int batch_size)
1745
- {
1746
- struct matrix_info_t
1747
- {
1748
- oneapi::mkl::transpose transpose_info[2 ];
1749
- Ts value_info[2 ];
1750
- std::int64_t size_info[3 ];
1751
- std::int64_t ld_info[3 ];
1752
- std::int64_t groupsize_info;
1753
- };
1754
-
1748
+ inline void gemm_batch_impl (sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1749
+ int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
1750
+ int ldb, const void * beta, void ** c, int ldc, int batch_size,
1751
+ matrix_info_t <float > * matrix_info) {
1755
1752
Ts alpha_value = dpct::get_value (reinterpret_cast <const Ts *>(alpha), q);
1756
1753
Ts beta_value = dpct::get_value (reinterpret_cast <const Ts *>(beta), q);
1757
1754
1758
- matrix_info_t *matrix_info =
1759
- (matrix_info_t *)std::malloc (sizeof (matrix_info_t ));
1760
1755
matrix_info->transpose_info [0 ] = a_trans;
1761
1756
matrix_info->transpose_info [1 ] = b_trans;
1762
1757
matrix_info->value_info [0 ] = alpha_value;
@@ -1773,23 +1768,18 @@ namespace dpct
1773
1768
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
1774
1769
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info ,
1775
1770
matrix_info->transpose_info + 1 , matrix_info->size_info , matrix_info->size_info + 1 ,
1776
- matrix_info->size_info + 2 , matrix_info-> value_info , reinterpret_cast <const Ta **>(a ),
1777
- matrix_info-> ld_info , reinterpret_cast <const Tb **>(b ), matrix_info->ld_info + 1 ,
1778
- matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 ,
1779
- &(matrix_info->groupsize_info ));
1771
+ matrix_info->size_info + 2 , reinterpret_cast <Ts *>(matrix_info-> value_info ),
1772
+ reinterpret_cast <const Ta **>(a ), matrix_info->ld_info , reinterpret_cast < const Tb **>(b) ,
1773
+ matrix_info->ld_info + 1 , reinterpret_cast <Ts *>( matrix_info->value_info + 1 ) ,
1774
+ reinterpret_cast <Tc **>(c), matrix_info-> ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1780
1775
#else
1781
1776
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
1782
1777
q, matrix_info->transpose_info , matrix_info->transpose_info + 1 , matrix_info->size_info ,
1783
- matrix_info->size_info + 1 , matrix_info->size_info + 2 , matrix_info->value_info ,
1778
+ matrix_info->size_info + 1 , matrix_info->size_info + 2 , reinterpret_cast <Ts *>( matrix_info->value_info ) ,
1784
1779
reinterpret_cast <const Ta **>(a), matrix_info->ld_info , reinterpret_cast <const Tb **>(b),
1785
- matrix_info->ld_info + 1 , matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c ),
1786
- matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1780
+ matrix_info->ld_info + 1 , reinterpret_cast <Ts *>( matrix_info->value_info + 1 ),
1781
+ reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1787
1782
#endif
1788
-
1789
- q.submit ([&](sycl::handler &cgh)
1790
- {
1791
- cgh.depends_on (e);
1792
- cgh.host_task ([=] { std::free (matrix_info); }); });
1793
1783
}
1794
1784
1795
1785
template <class Ta , class Tb , class Tc , class Ts >
@@ -2427,25 +2417,11 @@ namespace dpct
2427
2417
// / \param [in] ldc Leading dimension of C.
2428
2418
// / \param [in] batch_size Specifies the number of matrix multiply operations to perform.
2429
2419
// / \param [in] scaling_type Data type of the scaling factors.
2430
- inline void gemm_batch (sycl::queue &q, oneapi::mkl::transpose a_trans,
2431
- oneapi::mkl::transpose b_trans, int m, int n, int k,
2432
- const void *alpha, const void *a[],
2433
- library_data_t a_type, int lda, const void *b[],
2434
- library_data_t b_type, int ldb, const void *beta,
2435
- void *c[], library_data_t c_type, int ldc,
2436
- int batch_size, library_data_t scaling_type)
2437
- {
2438
- if (scaling_type == library_data_t ::real_float &&
2439
- c_type == library_data_t ::complex_float)
2440
- {
2441
- scaling_type = library_data_t ::complex_float;
2442
- }
2443
- else if (scaling_type == library_data_t ::real_double &&
2444
- c_type == library_data_t ::complex_double)
2445
- {
2446
- scaling_type = library_data_t ::complex_double;
2447
- }
2448
-
2420
+ inline void gemm_batch (sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2421
+ int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
2422
+ const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
2423
+ library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
2424
+ matrix_info_t <float > * matrix_info) {
2449
2425
std::uint64_t key =
2450
2426
detail::get_type_combination_id (a_type, b_type, c_type, scaling_type);
2451
2427
switch (key)
@@ -2454,68 +2430,41 @@ namespace dpct
2454
2430
library_data_t ::real_float, library_data_t ::real_float,
2455
2431
library_data_t ::real_float, library_data_t ::real_float):
2456
2432
{
2457
- detail::gemm_batch_impl<float , float , float , float >(
2458
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2459
- batch_size);
2433
+ detail::gemm_batch_impl<float , float , float , float >(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2434
+ beta, c, ldc, batch_size, matrix_info);
2460
2435
break ;
2461
2436
}
2462
2437
case detail::get_type_combination_id (
2463
2438
library_data_t ::real_double, library_data_t ::real_double,
2464
2439
library_data_t ::real_double, library_data_t ::real_double):
2465
2440
{
2466
- detail::gemm_batch_impl<double , double , double , double >(
2467
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2468
- batch_size);
2469
- break ;
2470
- }
2471
- case detail::get_type_combination_id (
2472
- library_data_t ::complex_float, library_data_t ::complex_float,
2473
- library_data_t ::complex_float, library_data_t ::complex_float):
2474
- {
2475
- detail::gemm_batch_impl<std::complex<float >, std::complex<float >,
2476
- std::complex<float >, std::complex<float >>(
2477
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2478
- batch_size);
2479
- break ;
2480
- }
2481
- case detail::get_type_combination_id (
2482
- library_data_t ::complex_double, library_data_t ::complex_double,
2483
- library_data_t ::complex_double, library_data_t ::complex_double):
2484
- {
2485
- detail::gemm_batch_impl<std::complex<double >, std::complex<double >,
2486
- std::complex<double >, std::complex<double >>(
2487
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2488
- batch_size);
2441
+ detail::gemm_batch_impl<double , double , double , double >(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2442
+ beta, c, ldc, batch_size, matrix_info);
2489
2443
break ;
2490
2444
}
2491
2445
case detail::get_type_combination_id (
2492
2446
library_data_t ::real_half, library_data_t ::real_half,
2493
2447
library_data_t ::real_half, library_data_t ::real_half):
2494
2448
{
2495
- detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
2496
- sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
2497
- a, lda, b, ldb, beta, c, ldc,
2498
- batch_size);
2449
+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2450
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2499
2451
break ;
2500
2452
}
2501
2453
#ifdef __INTEL_MKL__
2502
2454
case detail::get_type_combination_id (
2503
2455
library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
2504
2456
library_data_t ::real_bfloat16, library_data_t ::real_float):
2505
2457
{
2506
- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
2507
- oneapi::mkl::bfloat16, float >(
2508
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2509
- batch_size);
2458
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float >(
2459
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2510
2460
break ;
2511
2461
}
2512
2462
case detail::get_type_combination_id (
2513
2463
library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
2514
2464
library_data_t ::real_float, library_data_t ::real_float):
2515
2465
{
2516
- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float ,
2517
- float >(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2518
- b, ldb, beta, c, ldc, batch_size);
2466
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float , float >(
2467
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2519
2468
break ;
2520
2469
}
2521
2470
case detail::get_type_combination_id (
@@ -2526,28 +2475,25 @@ namespace dpct
2526
2475
dpct::get_value (reinterpret_cast <const std::int32_t *>(alpha), q);
2527
2476
float beta_float =
2528
2477
dpct::get_value (reinterpret_cast <const std::int32_t *>(beta), q);
2529
- detail::gemm_batch_impl<std::int8_t , std::int8_t , std::int32_t ,
2530
- float >(q, a_trans, b_trans, m, n, k, &alpha_float,
2531
- a, lda, b, ldb, &beta_float, c, ldc,
2532
- batch_size);
2478
+ detail::gemm_batch_impl<std::int8_t , std::int8_t , std::int32_t , float >(
2479
+ q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
2480
+ matrix_info);
2533
2481
break ;
2534
2482
}
2535
2483
case detail::get_type_combination_id (
2536
2484
library_data_t ::real_int8, library_data_t ::real_int8,
2537
2485
library_data_t ::real_float, library_data_t ::real_float):
2538
2486
{
2539
2487
detail::gemm_batch_impl<std::int8_t , std::int8_t , float , float >(
2540
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2541
- batch_size);
2488
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2542
2489
break ;
2543
2490
}
2544
2491
case detail::get_type_combination_id (
2545
2492
library_data_t ::real_half, library_data_t ::real_half,
2546
2493
library_data_t ::real_float, library_data_t ::real_float):
2547
2494
{
2548
2495
detail::gemm_batch_impl<sycl::half, sycl::half, float , float >(
2549
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2550
- batch_size);
2496
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2551
2497
break ;
2552
2498
}
2553
2499
#endif
@@ -2562,8 +2508,7 @@ namespace dpct
2562
2508
sycl::half alpha_half (alpha_value);
2563
2509
sycl::half beta_half (beta_value);
2564
2510
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2565
- q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
2566
- batch_size);
2511
+ q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
2567
2512
break ;
2568
2513
}
2569
2514
default :
0 commit comments