18
18
#include < syclcompat/math.hpp>
19
19
#include < oneapi/mkl.hpp>
20
20
#include < map>
21
+ #include < cassert>
21
22
22
23
#include " ggml-sycl.h"
23
24
#include " ggml.h"
@@ -88,6 +89,16 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
88
89
return device_type.str ();
89
90
}
90
91
92
+ template <typename Ts>
93
+ struct matrix_info_t
94
+ {
95
+ oneapi::mkl::transpose transpose_info[2 ];
96
+ Ts value_info[2 ];
97
+ std::int64_t size_info[3 ];
98
+ std::int64_t ld_info[3 ];
99
+ std::int64_t groupsize_info;
100
+ };
101
+
91
102
namespace dpct
92
103
{
93
104
typedef sycl::queue *queue_ptr;
@@ -1737,27 +1748,16 @@ namespace dpct
1737
1748
oneapi::mkl::transpose b_trans, int m, int n, int k,
1738
1749
const void *alpha, const void **a, int lda,
1739
1750
const void **b, int ldb, const void *beta, void **c,
1740
- int ldc, int batch_size)
1751
+ int ldc, int batch_size, matrix_info_t < double >* matrix_info )
1741
1752
{
1742
- struct matrix_info_t
1743
- {
1744
- oneapi::mkl::transpose transpose_info[2 ];
1745
- Ts value_info[2 ];
1746
- std::int64_t size_info[3 ];
1747
- std::int64_t ld_info[3 ];
1748
- std::int64_t groupsize_info;
1749
- };
1750
1753
1751
1754
Ts alpha_value = dpct::get_value (reinterpret_cast <const Ts *>(alpha), q);
1752
1755
Ts beta_value = dpct::get_value (reinterpret_cast <const Ts *>(beta), q);
1753
1756
1754
- // ggml_backend_sycl_host_buffer_type()->alloc_buffer;
1755
- auto tmp = ggml_backend_sycl_reg ( );
1756
- std::cout << " this is WARIO " << tmp-> iface . get_name (tmp) << ' \n ' ;
1757
+ // ::matrix_info_t<Ts> *matrix_info =
1758
+ // (::matrix_info_t<Ts> *)std::malloc(sizeof(matrix_info_t<Ts>) );
1759
+ // printf("test pointer %p alpha_value %f before\n", matrix_info, alpha_value) ;
1757
1760
1758
-
1759
- matrix_info_t *matrix_info =
1760
- (matrix_info_t *)std::malloc (sizeof (matrix_info_t ));
1761
1761
matrix_info->transpose_info [0 ] = a_trans;
1762
1762
matrix_info->transpose_info [1 ] = b_trans;
1763
1763
matrix_info->value_info [0 ] = alpha_value;
@@ -1770,13 +1770,15 @@ namespace dpct
1770
1770
matrix_info->ld_info [2 ] = ldc;
1771
1771
matrix_info->groupsize_info = batch_size;
1772
1772
1773
+ // printf("test pointer %p alpha_value %f\n", matrix_info, matrix_info->value_info[0]);;
1774
+
1773
1775
#ifdef GGML_SYCL_NVIDIA
1774
1776
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
1775
1777
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info ,
1776
1778
matrix_info->transpose_info + 1 , matrix_info->size_info , matrix_info->size_info + 1 ,
1777
- matrix_info->size_info + 2 , matrix_info->value_info , reinterpret_cast <const Ta **>(a),
1779
+ matrix_info->size_info + 2 , reinterpret_cast <Ts*>( matrix_info->value_info ) , reinterpret_cast <const Ta **>(a),
1778
1780
matrix_info->ld_info , reinterpret_cast <const Tb **>(b), matrix_info->ld_info + 1 ,
1779
- matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 ,
1781
+ reinterpret_cast <Ts*>( matrix_info->value_info + 1 ) , reinterpret_cast <Tc **>(c), matrix_info->ld_info + 2 , 1 ,
1780
1782
&(matrix_info->groupsize_info ));
1781
1783
#else
1782
1784
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch (
@@ -1786,11 +1788,14 @@ namespace dpct
1786
1788
matrix_info->ld_info + 1 , matrix_info->value_info + 1 , reinterpret_cast <Tc **>(c),
1787
1789
matrix_info->ld_info + 2 , 1 , &(matrix_info->groupsize_info ));
1788
1790
#endif
1791
+ // printf("gemm_launched\n");
1789
1792
1793
+ /*
1790
1794
q.submit([&](sycl::handler &cgh)
1791
1795
{
1792
1796
cgh.depends_on(e);
1793
1797
cgh.host_task([=] { std::free(matrix_info); }); });
1798
+ */
1794
1799
}
1795
1800
1796
1801
template <class Ta , class Tb , class Tc , class Ts >
@@ -2439,7 +2444,8 @@ namespace dpct
2439
2444
library_data_t a_type, int lda, const void *b[],
2440
2445
library_data_t b_type, int ldb, const void *beta,
2441
2446
void *c[], library_data_t c_type, int ldc,
2442
- int batch_size, library_data_t scaling_type)
2447
+ int batch_size, library_data_t scaling_type,
2448
+ matrix_info_t <double >* matrix_info)
2443
2449
{
2444
2450
if (scaling_type == library_data_t ::real_float &&
2445
2451
c_type == library_data_t ::complex_float)
@@ -2451,7 +2457,6 @@ namespace dpct
2451
2457
{
2452
2458
scaling_type = library_data_t ::complex_double;
2453
2459
}
2454
-
2455
2460
std::uint64_t key =
2456
2461
detail::get_type_combination_id (a_type, b_type, c_type, scaling_type);
2457
2462
switch (key)
@@ -2462,7 +2467,7 @@ namespace dpct
2462
2467
{
2463
2468
detail::gemm_batch_impl<float , float , float , float >(
2464
2469
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2465
- batch_size);
2470
+ batch_size, matrix_info );
2466
2471
break ;
2467
2472
}
2468
2473
case detail::get_type_combination_id (
@@ -2471,17 +2476,18 @@ namespace dpct
2471
2476
{
2472
2477
detail::gemm_batch_impl<double , double , double , double >(
2473
2478
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2474
- batch_size);
2479
+ batch_size, matrix_info );
2475
2480
break ;
2476
2481
}
2482
+ /*
2477
2483
case detail::get_type_combination_id(
2478
2484
library_data_t::complex_float, library_data_t::complex_float,
2479
2485
library_data_t::complex_float, library_data_t::complex_float):
2480
2486
{
2481
2487
detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
2482
2488
std::complex<float>, std::complex<float>>(
2483
2489
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2484
- batch_size);
2490
+ batch_size, matrix_info );
2485
2491
break;
2486
2492
}
2487
2493
case detail::get_type_combination_id(
@@ -2491,17 +2497,18 @@ namespace dpct
2491
2497
detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
2492
2498
std::complex<double>, std::complex<double>>(
2493
2499
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2494
- batch_size);
2500
+ batch_size, matrix_info );
2495
2501
break;
2496
2502
}
2503
+ */
2497
2504
case detail::get_type_combination_id (
2498
2505
library_data_t ::real_half, library_data_t ::real_half,
2499
2506
library_data_t ::real_half, library_data_t ::real_half):
2500
2507
{
2501
2508
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
2502
2509
sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
2503
2510
a, lda, b, ldb, beta, c, ldc,
2504
- batch_size);
2511
+ batch_size, matrix_info );
2505
2512
break ;
2506
2513
}
2507
2514
#ifdef __INTEL_MKL__
@@ -2512,7 +2519,7 @@ namespace dpct
2512
2519
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
2513
2520
oneapi::mkl::bfloat16, float >(
2514
2521
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2515
- batch_size);
2522
+ batch_size, matrix_info );
2516
2523
break ;
2517
2524
}
2518
2525
case detail::get_type_combination_id (
@@ -2521,7 +2528,7 @@ namespace dpct
2521
2528
{
2522
2529
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float ,
2523
2530
float >(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2524
- b, ldb, beta, c, ldc, batch_size);
2531
+ b, ldb, beta, c, ldc, batch_size, matrix_info );
2525
2532
break ;
2526
2533
}
2527
2534
#endif
@@ -2536,7 +2543,7 @@ namespace dpct
2536
2543
detail::gemm_batch_impl<std::int8_t , std::int8_t , std::int32_t ,
2537
2544
float >(q, a_trans, b_trans, m, n, k, &alpha_float,
2538
2545
a, lda, b, ldb, &beta_float, c, ldc,
2539
- batch_size);
2546
+ batch_size, matrix_info );
2540
2547
break ;
2541
2548
}
2542
2549
case detail::get_type_combination_id (
@@ -2545,7 +2552,7 @@ namespace dpct
2545
2552
{
2546
2553
detail::gemm_batch_impl<std::int8_t , std::int8_t , float , float >(
2547
2554
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2548
- batch_size);
2555
+ batch_size, matrix_info );
2549
2556
break ;
2550
2557
}
2551
2558
case detail::get_type_combination_id (
@@ -2554,7 +2561,7 @@ namespace dpct
2554
2561
{
2555
2562
detail::gemm_batch_impl<sycl::half, sycl::half, float , float >(
2556
2563
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2557
- batch_size);
2564
+ batch_size, matrix_info );
2558
2565
break ;
2559
2566
}
2560
2567
case detail::get_type_combination_id (
@@ -2569,7 +2576,7 @@ namespace dpct
2569
2576
sycl::half beta_half (beta_value);
2570
2577
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2571
2578
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
2572
- batch_size);
2579
+ batch_size, matrix_info );
2573
2580
break ;
2574
2581
}
2575
2582
default :
0 commit comments