Skip to content

Commit 6125929

Browse files
s-Nick authored and NeoZhangJianyu committed
SYCL: Introducing memory host pool (ggml-org#11251)
* Implement host pool for matrix_info

  Creating a new memory pool on the host to store memory location for matrix_info needed to launch gemm_batch from oneMKL/oneMath. Removing complex support in gemm_batch since it is not used in llama.cpp

* Remove unnecessary headers and cast
* Reorder member variable to avoid warning on initialization
* Formatting
* Remove unused variable
* Address PR review feedback - remove warning

---------

Signed-off-by: nscipione <[email protected]>
1 parent b8f6d2b commit 6125929

File tree

3 files changed

+141
-107
lines changed

3 files changed

+141
-107
lines changed

ggml/src/ggml-sycl/common.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,12 @@ struct ggml_backend_sycl_context {
311311
// pool
312312
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
313313

314+
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
315+
314316
static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
315317

318+
static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);
319+
316320
ggml_sycl_pool & pool(int device) {
317321
if (pools[device] == nullptr) {
318322
pools[device] = new_pool_for_device(stream(device,0), device);
@@ -323,6 +327,15 @@ struct ggml_backend_sycl_context {
323327
// Convenience overload: forwards to pool(int) using this context's `device` value.
ggml_sycl_pool & pool() {
    ggml_sycl_pool & current_pool = pool(device);
    return current_pool;
}
330+
331+
// Returns the host pool cached in host_pools[device]. The pool is created through
// new_pool_for_host() the first time it is requested, using stream(device, 0).
ggml_sycl_pool & host_pool(int device) {
    std::unique_ptr<ggml_sycl_pool> & slot = host_pools[device];
    if (!slot) {
        slot = new_pool_for_host(stream(device, 0), device);
    }
    return *slot;
}
337+
338+
// Convenience overload: forwards to host_pool(int) using this context's `device` value.
ggml_sycl_pool & host_pool() {
    return host_pool(device);
}
326339
};
327340

328341
static inline void exit_with_stack_print() {

ggml/src/ggml-sycl/dpct.hpp

Lines changed: 40 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
8282
return device_type.str();
8383
}
8484

85+
// Argument block for a single-group oneMKL gemm_batch call. gemm_batch_impl fills
// these fields and then passes pointers into them (e.g. `transpose_info + 1`,
// `&groupsize_info`) to oneapi::mkl::blas::column_major::gemm_batch, which is why
// each member is an array or an addressable scalar. Member order and array sizes
// therefore must not change.
//   transpose_info : [0] = a_trans, [1] = b_trans
//   value_info     : [0] = alpha; [1] is passed in the beta position (presumably beta)
//   size_info      : three consecutive entries are passed (presumably m, n, k - confirm)
//   ld_info        : entries are passed after a, b and c (presumably lda, ldb, ldc)
//   groupsize_info : passed by address as the last argument (presumably the batch size)
// NOTE(review): gemm_batch_impl always receives matrix_info_t<float> and
// reinterpret_casts value_info to Ts *; for a Ts other than float the stored floats
// would be read back with the wrong type - confirm only float scaling reaches it.
template <typename Ts> struct matrix_info_t {
86+
oneapi::mkl::transpose transpose_info[2];
87+
Ts value_info[2];
88+
std::int64_t size_info[3];
89+
std::int64_t ld_info[3];
90+
std::int64_t groupsize_info;
91+
};
92+
8593
namespace dpct
8694
{
8795
typedef sycl::queue *queue_ptr;
@@ -1737,26 +1745,13 @@ namespace dpct
17371745
};
17381746

17391747
template <class Ta, class Tb, class Tc, class Ts>
1740-
inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
1741-
oneapi::mkl::transpose b_trans, int m, int n, int k,
1742-
const void *alpha, const void **a, int lda,
1743-
const void **b, int ldb, const void *beta, void **c,
1744-
int ldc, int batch_size)
1745-
{
1746-
struct matrix_info_t
1747-
{
1748-
oneapi::mkl::transpose transpose_info[2];
1749-
Ts value_info[2];
1750-
std::int64_t size_info[3];
1751-
std::int64_t ld_info[3];
1752-
std::int64_t groupsize_info;
1753-
};
1754-
1748+
inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1749+
int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
1750+
int ldb, const void * beta, void ** c, int ldc, int batch_size,
1751+
matrix_info_t<float> * matrix_info) {
17551752
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
17561753
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
17571754

1758-
matrix_info_t *matrix_info =
1759-
(matrix_info_t *)std::malloc(sizeof(matrix_info_t));
17601755
matrix_info->transpose_info[0] = a_trans;
17611756
matrix_info->transpose_info[1] = b_trans;
17621757
matrix_info->value_info[0] = alpha_value;
@@ -1773,23 +1768,18 @@ namespace dpct
17731768
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
17741769
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
17751770
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
1776-
matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
1777-
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
1778-
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
1779-
&(matrix_info->groupsize_info));
1771+
matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
1772+
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
1773+
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
1774+
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
17801775
#else
17811776
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
17821777
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
1783-
matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
1778+
matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
17841779
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
1785-
matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
1786-
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
1780+
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
1781+
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
17871782
#endif
1788-
1789-
q.submit([&](sycl::handler &cgh)
1790-
{
1791-
cgh.depends_on(e);
1792-
cgh.host_task([=] { std::free(matrix_info); }); });
17931783
}
17941784

17951785
template <class Ta, class Tb, class Tc, class Ts>
@@ -2427,25 +2417,11 @@ namespace dpct
24272417
/// \param [in] ldc Leading dimension of C.
24282418
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
24292419
/// \param [in] scaling_type Data type of the scaling factors.
2430-
inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
2431-
oneapi::mkl::transpose b_trans, int m, int n, int k,
2432-
const void *alpha, const void *a[],
2433-
library_data_t a_type, int lda, const void *b[],
2434-
library_data_t b_type, int ldb, const void *beta,
2435-
void *c[], library_data_t c_type, int ldc,
2436-
int batch_size, library_data_t scaling_type)
2437-
{
2438-
if (scaling_type == library_data_t::real_float &&
2439-
c_type == library_data_t::complex_float)
2440-
{
2441-
scaling_type = library_data_t::complex_float;
2442-
}
2443-
else if (scaling_type == library_data_t::real_double &&
2444-
c_type == library_data_t::complex_double)
2445-
{
2446-
scaling_type = library_data_t::complex_double;
2447-
}
2448-
2420+
inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2421+
int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
2422+
const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
2423+
library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
2424+
matrix_info_t<float> * matrix_info) {
24492425
std::uint64_t key =
24502426
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
24512427
switch (key)
@@ -2454,68 +2430,41 @@ namespace dpct
24542430
library_data_t::real_float, library_data_t::real_float,
24552431
library_data_t::real_float, library_data_t::real_float):
24562432
{
2457-
detail::gemm_batch_impl<float, float, float, float>(
2458-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2459-
batch_size);
2433+
detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2434+
beta, c, ldc, batch_size, matrix_info);
24602435
break;
24612436
}
24622437
case detail::get_type_combination_id(
24632438
library_data_t::real_double, library_data_t::real_double,
24642439
library_data_t::real_double, library_data_t::real_double):
24652440
{
2466-
detail::gemm_batch_impl<double, double, double, double>(
2467-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2468-
batch_size);
2469-
break;
2470-
}
2471-
case detail::get_type_combination_id(
2472-
library_data_t::complex_float, library_data_t::complex_float,
2473-
library_data_t::complex_float, library_data_t::complex_float):
2474-
{
2475-
detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
2476-
std::complex<float>, std::complex<float>>(
2477-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2478-
batch_size);
2479-
break;
2480-
}
2481-
case detail::get_type_combination_id(
2482-
library_data_t::complex_double, library_data_t::complex_double,
2483-
library_data_t::complex_double, library_data_t::complex_double):
2484-
{
2485-
detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
2486-
std::complex<double>, std::complex<double>>(
2487-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2488-
batch_size);
2441+
detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2442+
beta, c, ldc, batch_size, matrix_info);
24892443
break;
24902444
}
24912445
case detail::get_type_combination_id(
24922446
library_data_t::real_half, library_data_t::real_half,
24932447
library_data_t::real_half, library_data_t::real_half):
24942448
{
2495-
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
2496-
sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
2497-
a, lda, b, ldb, beta, c, ldc,
2498-
batch_size);
2449+
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2450+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24992451
break;
25002452
}
25012453
#ifdef __INTEL_MKL__
25022454
case detail::get_type_combination_id(
25032455
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
25042456
library_data_t::real_bfloat16, library_data_t::real_float):
25052457
{
2506-
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
2507-
oneapi::mkl::bfloat16, float>(
2508-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2509-
batch_size);
2458+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2459+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25102460
break;
25112461
}
25122462
case detail::get_type_combination_id(
25132463
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
25142464
library_data_t::real_float, library_data_t::real_float):
25152465
{
2516-
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
2517-
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2518-
b, ldb, beta, c, ldc, batch_size);
2466+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2467+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25192468
break;
25202469
}
25212470
case detail::get_type_combination_id(
@@ -2526,28 +2475,25 @@ namespace dpct
25262475
dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
25272476
float beta_float =
25282477
dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
2529-
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
2530-
float>(q, a_trans, b_trans, m, n, k, &alpha_float,
2531-
a, lda, b, ldb, &beta_float, c, ldc,
2532-
batch_size);
2478+
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
2479+
q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
2480+
matrix_info);
25332481
break;
25342482
}
25352483
case detail::get_type_combination_id(
25362484
library_data_t::real_int8, library_data_t::real_int8,
25372485
library_data_t::real_float, library_data_t::real_float):
25382486
{
25392487
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
2540-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2541-
batch_size);
2488+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25422489
break;
25432490
}
25442491
case detail::get_type_combination_id(
25452492
library_data_t::real_half, library_data_t::real_half,
25462493
library_data_t::real_float, library_data_t::real_float):
25472494
{
25482495
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
2549-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2550-
batch_size);
2496+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25512497
break;
25522498
}
25532499
#endif
@@ -2562,8 +2508,7 @@ namespace dpct
25622508
sycl::half alpha_half(alpha_value);
25632509
sycl::half beta_half(beta_value);
25642510
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2565-
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
2566-
batch_size);
2511+
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
25672512
break;
25682513
}
25692514
default:

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 88 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,12 +1010,91 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
10101010
}
10111011
};
10121012

1013-
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device_id) {
1013+
struct ggml_sycl_pool_host : public ggml_sycl_pool {
1014+
queue_ptr qptr;
1015+
int device;
1016+
1017+
inline static int counter{ 0 };
1018+
1019+
struct ggml_sycl_buffer {
1020+
void * ptr = nullptr;
1021+
size_t size = 0;
1022+
};
1023+
1024+
// Set arbitrarly to 64
1025+
static constexpr int MAX_POOL_SIZE{ 64 };
1026+
std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
1027+
size_t pool_size = 0;
1028+
1029+
explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
1030+
1031+
~ggml_sycl_pool_host() {
1032+
for (int i = 0; i < MAX_POOL_SIZE; ++i) {
1033+
ggml_sycl_buffer & b = buffer_pool[i];
1034+
if (b.ptr != nullptr) {
1035+
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
1036+
b.ptr = nullptr;
1037+
pool_size -= b.size;
1038+
b.size = 0;
1039+
}
1040+
}
1041+
counter = 0;
1042+
}
1043+
1044+
void * alloc(size_t size, size_t * actual_size) override {
1045+
if (counter == MAX_POOL_SIZE) {
1046+
ggml_sycl_buffer b = buffer_pool[0];
1047+
void * ptr = b.ptr;
1048+
*actual_size = b.size;
1049+
counter = 1;
1050+
return ptr;
1051+
}
1052+
ggml_sycl_buffer & b = buffer_pool[counter];
1053+
1054+
if (b.ptr == nullptr) {
1055+
void * ptr;
1056+
1057+
SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
1058+
if (!ptr) {
1059+
GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
1060+
return nullptr;
1061+
}
1062+
pool_size += size;
1063+
*actual_size = size;
1064+
counter = counter + 1;
1065+
return ptr;
1066+
} else {
1067+
++counter;
1068+
b.size = size;
1069+
return b.ptr;
1070+
}
1071+
}
1072+
1073+
void free(void * ptr, size_t size) override {
1074+
// if the pool is not completed add the pointer to it in place of the first nullptr found.
1075+
// Otherwise do nothing, pointers will be freed once the pool is deallocated.
1076+
for (int i = 0; i < MAX_POOL_SIZE; ++i) {
1077+
ggml_sycl_buffer & b = buffer_pool[i];
1078+
if (b.ptr == nullptr) {
1079+
b.ptr = ptr;
1080+
b.size = size;
1081+
return;
1082+
}
1083+
}
1084+
}
1085+
};
1086+
1087+
// Factory for the host-side pool, used to speed up memory management on the host.
// The result is returned through the ggml_sycl_pool base interface.
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
    return std::make_unique<ggml_sycl_pool_host>(qptr, device);
}
1091+
1092+
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
10141093
// TBD: NO VMM support
1015-
// if (ggml_sycl_info().devices[device_id].vmm) {
1016-
// return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(device_id));
1094+
// if (ggml_sycl_info().devices[device].vmm) {
1095+
// return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(device));
10171096
// }
1018-
return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device_id));
1097+
return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));
10191098
}
10201099

10211100
// TBD pool with virtual memory management
@@ -3230,6 +3309,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
32303309

32313310
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
32323311
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
3312+
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
32333313

32343314
sycl::range<3> block_dims(1, ne12, ne13);
32353315
/*
@@ -3258,14 +3338,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
32583338
});
32593339
}
32603340
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
3261-
*main_stream, oneapi::mkl::transpose::trans,
3262-
oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3263-
(const void **)(ptrs_src.get() + 0 * ne23),
3264-
dpct::library_data_t::real_half, nb01 / nb00,
3265-
(const void **)(ptrs_src.get() + 1 * ne23),
3266-
dpct::library_data_t::real_half, nb11 / nb10, beta,
3267-
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
3268-
cu_compute_type)));
3341+
*main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3342+
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
3343+
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
3344+
(void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
32693345
}
32703346
}
32713347
catch (sycl::exception const &exc) {

0 commit comments

Comments
 (0)