@@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
82
82
return device_type.str ();
83
83
}
84
84
85
+ template <typename Ts> struct matrix_info_t {
86
+ oneapi::mkl::transpose transpose_info[2 ];
87
+ Ts value_info[2 ];
88
+ std::int64_t size_info[3 ];
89
+ std::int64_t ld_info[3 ];
90
+ std::int64_t groupsize_info;
91
+ };
92
+
85
93
namespace dpct
86
94
{
87
95
typedef sycl::queue *queue_ptr;
@@ -1727,26 +1735,13 @@ namespace dpct
1727
1735
};
1728
1736
1729
1737
template <class Ta , class Tb , class Tc , class Ts >
1730
- inline void gemm_batch_impl (sycl::queue &q, oneapi::mkl::transpose a_trans,
1731
- oneapi::mkl::transpose b_trans, int m, int n, int k,
1732
- const void *alpha, const void **a, int lda,
1733
- const void **b, int ldb, const void *beta, void **c,
1734
- int ldc, int batch_size)
1735
- {
1736
- struct matrix_info_t
1737
- {
1738
- oneapi::mkl::transpose transpose_info[2 ];
1739
- Ts value_info[2 ];
1740
- std::int64_t size_info[3 ];
1741
- std::int64_t ld_info[3 ];
1742
- std::int64_t groupsize_info;
1743
- };
1744
-
1738
+ inline void gemm_batch_impl (sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1739
+ int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
1740
+ int ldb, const void * beta, void ** c, int ldc, int batch_size,
1741
+ matrix_info_t <float > * matrix_info) {
1745
1742
Ts alpha_value = dpct::get_value (reinterpret_cast <const Ts *>(alpha), q);
1746
1743
Ts beta_value = dpct::get_value (reinterpret_cast <const Ts *>(beta), q);
1747
1744
1748
- matrix_info_t *matrix_info =
1749
- (matrix_info_t *)std::malloc (sizeof (matrix_info_t ));
1750
1745
matrix_info->transpose_info [0 ] = a_trans;
1751
1746
matrix_info->transpose_info [1 ] = b_trans;
1752
1747
matrix_info->value_info [0 ] = alpha_value;
@@ -1763,23 +1758,18 @@ namespace dpct
1763
1758
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
1764
1759
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info ,
1765
1760
matrix_info->transpose_info + 1 , matrix_info->size_info , matrix_info->size_info + 1 ,
1766
- matrix_info->size_info + 2 , matrix_info-> value_info , reinterpret_cast <const Ta **>(a ),
1767
- matrix_info-> ld_info , reinterpret_cast <const Tb **>(b ), matrix_info->ld_info + 1 ,
1768
- matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 ,
1769
- &(matrix_info->groupsize_info ));
1761
+ matrix_info->size_info + 2 , reinterpret_cast <Ts *>(matrix_info-> value_info ),
1762
+ reinterpret_cast <const Ta **>(a ), matrix_info->ld_info , reinterpret_cast < const Tb **>(b) ,
1763
+ matrix_info->ld_info + 1 , reinterpret_cast <Ts *>( matrix_info->value_info + 1 ) ,
1764
+ reinterpret_cast <Tc **>(c), matrix_info-> ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1770
1765
#else
1771
1766
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
1772
1767
q, matrix_info->transpose_info , matrix_info->transpose_info + 1 , matrix_info->size_info ,
1773
- matrix_info->size_info + 1 , matrix_info->size_info + 2 , matrix_info->value_info ,
1768
+ matrix_info->size_info + 1 , matrix_info->size_info + 2 , reinterpret_cast <Ts *>( matrix_info->value_info ) ,
1774
1769
reinterpret_cast <const Ta **>(a), matrix_info->ld_info , reinterpret_cast <const Tb **>(b),
1775
- matrix_info->ld_info + 1 , matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c ),
1776
- matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1770
+ matrix_info->ld_info + 1 , reinterpret_cast <Ts *>( matrix_info->value_info + 1 ),
1771
+ reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1777
1772
#endif
1778
-
1779
- q.submit ([&](sycl::handler &cgh)
1780
- {
1781
- cgh.depends_on (e);
1782
- cgh.host_task ([=] { std::free (matrix_info); }); });
1783
1773
}
1784
1774
1785
1775
template <class Ta , class Tb , class Tc , class Ts >
@@ -2422,25 +2412,11 @@ namespace dpct
2422
2412
// / \param [in] ldc Leading dimension of C.
2423
2413
// / \param [in] batch_size Specifies the number of matrix multiply operations to perform.
2424
2414
// / \param [in] scaling_type Data type of the scaling factors.
2425
- inline void gemm_batch (sycl::queue &q, oneapi::mkl::transpose a_trans,
2426
- oneapi::mkl::transpose b_trans, int m, int n, int k,
2427
- const void *alpha, const void *a[],
2428
- library_data_t a_type, int lda, const void *b[],
2429
- library_data_t b_type, int ldb, const void *beta,
2430
- void *c[], library_data_t c_type, int ldc,
2431
- int batch_size, library_data_t scaling_type)
2432
- {
2433
- if (scaling_type == library_data_t ::real_float &&
2434
- c_type == library_data_t ::complex_float)
2435
- {
2436
- scaling_type = library_data_t ::complex_float;
2437
- }
2438
- else if (scaling_type == library_data_t ::real_double &&
2439
- c_type == library_data_t ::complex_double)
2440
- {
2441
- scaling_type = library_data_t ::complex_double;
2442
- }
2443
-
2415
+ inline void gemm_batch (sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2416
+ int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
2417
+ const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
2418
+ library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
2419
+ matrix_info_t <float > * matrix_info) {
2444
2420
std::uint64_t key =
2445
2421
detail::get_type_combination_id (a_type, b_type, c_type, scaling_type);
2446
2422
switch (key)
@@ -2449,68 +2425,41 @@ namespace dpct
2449
2425
library_data_t ::real_float, library_data_t ::real_float,
2450
2426
library_data_t ::real_float, library_data_t ::real_float):
2451
2427
{
2452
- detail::gemm_batch_impl<float , float , float , float >(
2453
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2454
- batch_size);
2428
+ detail::gemm_batch_impl<float , float , float , float >(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2429
+ beta, c, ldc, batch_size, matrix_info);
2455
2430
break ;
2456
2431
}
2457
2432
case detail::get_type_combination_id (
2458
2433
library_data_t ::real_double, library_data_t ::real_double,
2459
2434
library_data_t ::real_double, library_data_t ::real_double):
2460
2435
{
2461
- detail::gemm_batch_impl<double , double , double , double >(
2462
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2463
- batch_size);
2464
- break ;
2465
- }
2466
- case detail::get_type_combination_id (
2467
- library_data_t ::complex_float, library_data_t ::complex_float,
2468
- library_data_t ::complex_float, library_data_t ::complex_float):
2469
- {
2470
- detail::gemm_batch_impl<std::complex<float >, std::complex<float >,
2471
- std::complex<float >, std::complex<float >>(
2472
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2473
- batch_size);
2474
- break ;
2475
- }
2476
- case detail::get_type_combination_id (
2477
- library_data_t ::complex_double, library_data_t ::complex_double,
2478
- library_data_t ::complex_double, library_data_t ::complex_double):
2479
- {
2480
- detail::gemm_batch_impl<std::complex<double >, std::complex<double >,
2481
- std::complex<double >, std::complex<double >>(
2482
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2483
- batch_size);
2436
+ detail::gemm_batch_impl<double , double , double , double >(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2437
+ beta, c, ldc, batch_size, matrix_info);
2484
2438
break ;
2485
2439
}
2486
2440
case detail::get_type_combination_id (
2487
2441
library_data_t ::real_half, library_data_t ::real_half,
2488
2442
library_data_t ::real_half, library_data_t ::real_half):
2489
2443
{
2490
- detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
2491
- sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
2492
- a, lda, b, ldb, beta, c, ldc,
2493
- batch_size);
2444
+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2445
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2494
2446
break ;
2495
2447
}
2496
2448
#ifdef __INTEL_MKL__
2497
2449
case detail::get_type_combination_id (
2498
2450
library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
2499
2451
library_data_t ::real_bfloat16, library_data_t ::real_float):
2500
2452
{
2501
- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
2502
- oneapi::mkl::bfloat16, float >(
2503
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2504
- batch_size);
2453
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float >(
2454
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2505
2455
break ;
2506
2456
}
2507
2457
case detail::get_type_combination_id (
2508
2458
library_data_t ::real_bfloat16, library_data_t ::real_bfloat16,
2509
2459
library_data_t ::real_float, library_data_t ::real_float):
2510
2460
{
2511
- detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float ,
2512
- float >(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2513
- b, ldb, beta, c, ldc, batch_size);
2461
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float , float >(
2462
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2514
2463
break ;
2515
2464
}
2516
2465
#endif
@@ -2522,28 +2471,25 @@ namespace dpct
2522
2471
dpct::get_value (reinterpret_cast <const std::int32_t *>(alpha), q);
2523
2472
float beta_float =
2524
2473
dpct::get_value (reinterpret_cast <const std::int32_t *>(beta), q);
2525
- detail::gemm_batch_impl<std::int8_t , std::int8_t , std::int32_t ,
2526
- float >(q, a_trans, b_trans, m, n, k, &alpha_float,
2527
- a, lda, b, ldb, &beta_float, c, ldc,
2528
- batch_size);
2474
+ detail::gemm_batch_impl<std::int8_t , std::int8_t , std::int32_t , float >(
2475
+ q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
2476
+ matrix_info);
2529
2477
break ;
2530
2478
}
2531
2479
case detail::get_type_combination_id (
2532
2480
library_data_t ::real_int8, library_data_t ::real_int8,
2533
2481
library_data_t ::real_float, library_data_t ::real_float):
2534
2482
{
2535
2483
detail::gemm_batch_impl<std::int8_t , std::int8_t , float , float >(
2536
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2537
- batch_size);
2484
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2538
2485
break ;
2539
2486
}
2540
2487
case detail::get_type_combination_id (
2541
2488
library_data_t ::real_half, library_data_t ::real_half,
2542
2489
library_data_t ::real_float, library_data_t ::real_float):
2543
2490
{
2544
2491
detail::gemm_batch_impl<sycl::half, sycl::half, float , float >(
2545
- q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2546
- batch_size);
2492
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2547
2493
break ;
2548
2494
}
2549
2495
case detail::get_type_combination_id (
@@ -2557,8 +2503,7 @@ namespace dpct
2557
2503
sycl::half alpha_half (alpha_value);
2558
2504
sycl::half beta_half (beta_value);
2559
2505
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2560
- q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
2561
- batch_size);
2506
+ q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
2562
2507
break ;
2563
2508
}
2564
2509
default :
0 commit comments