Skip to content

Commit 1de15fa

Browse files
committed
Fix bugs in host_mem pool implementation.
First fully working implementation in the benchmark and llama-cli. There are still issues in the backend tests. Signed-off-by: nscipione <[email protected]>
1 parent 8da28db commit 1de15fa

File tree

3 files changed

+120
-56
lines changed

3 files changed

+120
-56
lines changed

ggml/src/ggml-sycl/common.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ void* ggml_sycl_host_malloc(size_t size) try {
2929
dpct::err0 err = CHECK_TRY_ERROR(
3030
ptr = (void*)sycl::malloc_host(size, dpct::get_in_order_queue()));
3131

32-
printf("Luigi\n");
32+
//printf("Luigi\n");
3333
if (err != 0) {
3434
// clear the error
3535
GGML_LOG_ERROR("WARNING: failed to allocate %.2f MB of pinned memory: %s\n", size / 1024.0 / 1024.0, "syclGetErrorString is not supported");

ggml/src/ggml-sycl/dpct/helper.hpp

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <syclcompat/math.hpp>
1919
#include <oneapi/mkl.hpp>
2020
#include <map>
21+
#include <cassert>
2122

2223
#include "ggml-sycl.h"
2324
#include "ggml.h"
@@ -88,6 +89,16 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
8889
return device_type.str();
8990
}
9091

92+
template<typename Ts> // Ts: element type of the scaling values kept in value_info
93+
struct matrix_info_t // Argument bundle that gemm_batch_impl fills in and passes, field by field, as pointers to oneMKL gemm_batch
94+
{
95+
oneapi::mkl::transpose transpose_info[2]; // [0] = a_trans, [1] = b_trans
96+
Ts value_info[2]; // [0] = alpha; [1] is passed in the beta position (its assignment is outside the visible hunks)
97+
std::int64_t size_info[3]; // passed as three consecutive size pointers — presumably m, n, k in that order; TODO confirm
98+
std::int64_t ld_info[3]; // leading dimensions; [2] = ldc, [0] and [1] are passed next to a and b respectively
99+
std::int64_t groupsize_info; // = batch_size; passed by address as the size of the single group
100+
};
101+
91102
namespace dpct
92103
{
93104
typedef sycl::queue *queue_ptr;
@@ -1737,27 +1748,16 @@ namespace dpct
17371748
oneapi::mkl::transpose b_trans, int m, int n, int k,
17381749
const void *alpha, const void **a, int lda,
17391750
const void **b, int ldb, const void *beta, void **c,
1740-
int ldc, int batch_size)
1751+
int ldc, int batch_size, matrix_info_t<double>* matrix_info)
17411752
{
1742-
struct matrix_info_t
1743-
{
1744-
oneapi::mkl::transpose transpose_info[2];
1745-
Ts value_info[2];
1746-
std::int64_t size_info[3];
1747-
std::int64_t ld_info[3];
1748-
std::int64_t groupsize_info;
1749-
};
17501753

17511754
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
17521755
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
17531756

1754-
//ggml_backend_sycl_host_buffer_type()->alloc_buffer;
1755-
auto tmp = ggml_backend_sycl_reg();
1756-
std::cout << "this is WARIO " << tmp->iface.get_name(tmp) << '\n';
1757+
//::matrix_info_t<Ts> *matrix_info =
1758+
//(::matrix_info_t<Ts> *)std::malloc(sizeof(matrix_info_t<Ts>));
1759+
//printf("test pointer %p alpha_value %f before\n", matrix_info, alpha_value);
17571760

1758-
1759-
matrix_info_t *matrix_info =
1760-
(matrix_info_t *)std::malloc(sizeof(matrix_info_t));
17611761
matrix_info->transpose_info[0] = a_trans;
17621762
matrix_info->transpose_info[1] = b_trans;
17631763
matrix_info->value_info[0] = alpha_value;
@@ -1770,13 +1770,15 @@ namespace dpct
17701770
matrix_info->ld_info[2] = ldc;
17711771
matrix_info->groupsize_info = batch_size;
17721772

1773+
//printf("test pointer %p alpha_value %f\n", matrix_info, matrix_info->value_info[0]);;
1774+
17731775
#ifdef GGML_SYCL_NVIDIA
17741776
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
17751777
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
17761778
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
1777-
matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
1779+
matrix_info->size_info + 2, reinterpret_cast<Ts*>(matrix_info->value_info), reinterpret_cast<const Ta **>(a),
17781780
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
1779-
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
1781+
reinterpret_cast<Ts*>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
17801782
&(matrix_info->groupsize_info));
17811783
#else
17821784
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
@@ -1786,11 +1788,14 @@ namespace dpct
17861788
matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
17871789
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
17881790
#endif
1791+
//printf("gemm_launched\n");
17891792

1793+
/*
17901794
q.submit([&](sycl::handler &cgh)
17911795
{
17921796
cgh.depends_on(e);
17931797
cgh.host_task([=] { std::free(matrix_info); }); });
1798+
*/
17941799
}
17951800

17961801
template <class Ta, class Tb, class Tc, class Ts>
@@ -2439,7 +2444,8 @@ namespace dpct
24392444
library_data_t a_type, int lda, const void *b[],
24402445
library_data_t b_type, int ldb, const void *beta,
24412446
void *c[], library_data_t c_type, int ldc,
2442-
int batch_size, library_data_t scaling_type)
2447+
int batch_size, library_data_t scaling_type,
2448+
matrix_info_t<double>* matrix_info)
24432449
{
24442450
if (scaling_type == library_data_t::real_float &&
24452451
c_type == library_data_t::complex_float)
@@ -2451,7 +2457,6 @@ namespace dpct
24512457
{
24522458
scaling_type = library_data_t::complex_double;
24532459
}
2454-
24552460
std::uint64_t key =
24562461
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
24572462
switch (key)
@@ -2462,7 +2467,7 @@ namespace dpct
24622467
{
24632468
detail::gemm_batch_impl<float, float, float, float>(
24642469
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2465-
batch_size);
2470+
batch_size, matrix_info);
24662471
break;
24672472
}
24682473
case detail::get_type_combination_id(
@@ -2471,17 +2476,18 @@ namespace dpct
24712476
{
24722477
detail::gemm_batch_impl<double, double, double, double>(
24732478
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2474-
batch_size);
2479+
batch_size, matrix_info);
24752480
break;
24762481
}
2482+
/*
24772483
case detail::get_type_combination_id(
24782484
library_data_t::complex_float, library_data_t::complex_float,
24792485
library_data_t::complex_float, library_data_t::complex_float):
24802486
{
24812487
detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
24822488
std::complex<float>, std::complex<float>>(
24832489
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2484-
batch_size);
2490+
batch_size, matrix_info);
24852491
break;
24862492
}
24872493
case detail::get_type_combination_id(
@@ -2491,17 +2497,18 @@ namespace dpct
24912497
detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
24922498
std::complex<double>, std::complex<double>>(
24932499
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2494-
batch_size);
2500+
batch_size, matrix_info);
24952501
break;
24962502
}
2503+
*/
24972504
case detail::get_type_combination_id(
24982505
library_data_t::real_half, library_data_t::real_half,
24992506
library_data_t::real_half, library_data_t::real_half):
25002507
{
25012508
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
25022509
sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
25032510
a, lda, b, ldb, beta, c, ldc,
2504-
batch_size);
2511+
batch_size, matrix_info);
25052512
break;
25062513
}
25072514
#ifdef __INTEL_MKL__
@@ -2512,7 +2519,7 @@ namespace dpct
25122519
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
25132520
oneapi::mkl::bfloat16, float>(
25142521
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2515-
batch_size);
2522+
batch_size, matrix_info);
25162523
break;
25172524
}
25182525
case detail::get_type_combination_id(
@@ -2521,7 +2528,7 @@ namespace dpct
25212528
{
25222529
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
25232530
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2524-
b, ldb, beta, c, ldc, batch_size);
2531+
b, ldb, beta, c, ldc, batch_size, matrix_info);
25252532
break;
25262533
}
25272534
#endif
@@ -2536,7 +2543,7 @@ namespace dpct
25362543
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
25372544
float>(q, a_trans, b_trans, m, n, k, &alpha_float,
25382545
a, lda, b, ldb, &beta_float, c, ldc,
2539-
batch_size);
2546+
batch_size, matrix_info);
25402547
break;
25412548
}
25422549
case detail::get_type_combination_id(
@@ -2545,7 +2552,7 @@ namespace dpct
25452552
{
25462553
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
25472554
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2548-
batch_size);
2555+
batch_size, matrix_info);
25492556
break;
25502557
}
25512558
case detail::get_type_combination_id(
@@ -2554,7 +2561,7 @@ namespace dpct
25542561
{
25552562
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
25562563
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2557-
batch_size);
2564+
batch_size, matrix_info);
25582565
break;
25592566
}
25602567
case detail::get_type_combination_id(
@@ -2569,7 +2576,7 @@ namespace dpct
25692576
sycl::half beta_half(beta_value);
25702577
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
25712578
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
2572-
batch_size);
2579+
batch_size, matrix_info);
25732580
break;
25742581
}
25752582
default:

0 commit comments

Comments
 (0)