Skip to content

Commit 99487b5

Browse files
authored
SYCL: Introducing memory host pool (#11251)
* Implement host pool for matrix_info

  Create a new memory pool on the host to store the memory location for matrix_info needed to launch gemm_batch from oneMKL/oneMath. Remove complex support in gemm_batch since it is not used in llama.cpp.

* Remove unnecessary headers and cast
* Reorder member variable to avoid warning on initialization
* Formatting
* Remove unused variable
* Address PR review feedback - remove warning

---------

Signed-off-by: nscipione <[email protected]>
1 parent a1649cc commit 99487b5

File tree

3 files changed

+137
-103
lines changed

3 files changed

+137
-103
lines changed

ggml/src/ggml-sycl/common.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,12 @@ struct ggml_backend_sycl_context {
333333
// pool
334334
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
335335

336+
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
337+
336338
static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
337339

340+
static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);
341+
338342
ggml_sycl_pool & pool(int device) {
339343
if (pools[device] == nullptr) {
340344
pools[device] = new_pool_for_device(stream(device,0), device);
@@ -345,6 +349,15 @@ struct ggml_backend_sycl_context {
345349
ggml_sycl_pool & pool() {
346350
return pool(device);
347351
}
352+
353+
ggml_sycl_pool & host_pool(int device) {
354+
if (host_pools[device] == nullptr) {
355+
host_pools[device] = new_pool_for_host(stream(device, 0), device);
356+
}
357+
return *host_pools[device];
358+
}
359+
360+
ggml_sycl_pool & host_pool() { return host_pool(device); }
348361
};
349362

350363
// common device functions

ggml/src/ggml-sycl/dpct/helper.hpp

Lines changed: 40 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
8282
return device_type.str();
8383
}
8484

85+
template <typename Ts> struct matrix_info_t {
86+
oneapi::mkl::transpose transpose_info[2];
87+
Ts value_info[2];
88+
std::int64_t size_info[3];
89+
std::int64_t ld_info[3];
90+
std::int64_t groupsize_info;
91+
};
92+
8593
namespace dpct
8694
{
8795
typedef sycl::queue *queue_ptr;
@@ -1727,26 +1735,13 @@ namespace dpct
17271735
};
17281736

17291737
template <class Ta, class Tb, class Tc, class Ts>
1730-
inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
1731-
oneapi::mkl::transpose b_trans, int m, int n, int k,
1732-
const void *alpha, const void **a, int lda,
1733-
const void **b, int ldb, const void *beta, void **c,
1734-
int ldc, int batch_size)
1735-
{
1736-
struct matrix_info_t
1737-
{
1738-
oneapi::mkl::transpose transpose_info[2];
1739-
Ts value_info[2];
1740-
std::int64_t size_info[3];
1741-
std::int64_t ld_info[3];
1742-
std::int64_t groupsize_info;
1743-
};
1744-
1738+
inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1739+
int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
1740+
int ldb, const void * beta, void ** c, int ldc, int batch_size,
1741+
matrix_info_t<float> * matrix_info) {
17451742
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
17461743
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
17471744

1748-
matrix_info_t *matrix_info =
1749-
(matrix_info_t *)std::malloc(sizeof(matrix_info_t));
17501745
matrix_info->transpose_info[0] = a_trans;
17511746
matrix_info->transpose_info[1] = b_trans;
17521747
matrix_info->value_info[0] = alpha_value;
@@ -1763,23 +1758,18 @@ namespace dpct
17631758
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
17641759
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
17651760
matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
1766-
matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
1767-
matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
1768-
matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
1769-
&(matrix_info->groupsize_info));
1761+
matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
1762+
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
1763+
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
1764+
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
17701765
#else
17711766
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
17721767
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
1773-
matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
1768+
matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
17741769
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
1775-
matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
1776-
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
1770+
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
1771+
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
17771772
#endif
1778-
1779-
q.submit([&](sycl::handler &cgh)
1780-
{
1781-
cgh.depends_on(e);
1782-
cgh.host_task([=] { std::free(matrix_info); }); });
17831773
}
17841774

17851775
template <class Ta, class Tb, class Tc, class Ts>
@@ -2422,25 +2412,11 @@ namespace dpct
24222412
/// \param [in] ldc Leading dimension of C.
24232413
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
24242414
/// \param [in] scaling_type Data type of the scaling factors.
2425-
inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
2426-
oneapi::mkl::transpose b_trans, int m, int n, int k,
2427-
const void *alpha, const void *a[],
2428-
library_data_t a_type, int lda, const void *b[],
2429-
library_data_t b_type, int ldb, const void *beta,
2430-
void *c[], library_data_t c_type, int ldc,
2431-
int batch_size, library_data_t scaling_type)
2432-
{
2433-
if (scaling_type == library_data_t::real_float &&
2434-
c_type == library_data_t::complex_float)
2435-
{
2436-
scaling_type = library_data_t::complex_float;
2437-
}
2438-
else if (scaling_type == library_data_t::real_double &&
2439-
c_type == library_data_t::complex_double)
2440-
{
2441-
scaling_type = library_data_t::complex_double;
2442-
}
2443-
2415+
inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2416+
int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
2417+
const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
2418+
library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
2419+
matrix_info_t<float> * matrix_info) {
24442420
std::uint64_t key =
24452421
detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
24462422
switch (key)
@@ -2449,68 +2425,41 @@ namespace dpct
24492425
library_data_t::real_float, library_data_t::real_float,
24502426
library_data_t::real_float, library_data_t::real_float):
24512427
{
2452-
detail::gemm_batch_impl<float, float, float, float>(
2453-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2454-
batch_size);
2428+
detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2429+
beta, c, ldc, batch_size, matrix_info);
24552430
break;
24562431
}
24572432
case detail::get_type_combination_id(
24582433
library_data_t::real_double, library_data_t::real_double,
24592434
library_data_t::real_double, library_data_t::real_double):
24602435
{
2461-
detail::gemm_batch_impl<double, double, double, double>(
2462-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2463-
batch_size);
2464-
break;
2465-
}
2466-
case detail::get_type_combination_id(
2467-
library_data_t::complex_float, library_data_t::complex_float,
2468-
library_data_t::complex_float, library_data_t::complex_float):
2469-
{
2470-
detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
2471-
std::complex<float>, std::complex<float>>(
2472-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2473-
batch_size);
2474-
break;
2475-
}
2476-
case detail::get_type_combination_id(
2477-
library_data_t::complex_double, library_data_t::complex_double,
2478-
library_data_t::complex_double, library_data_t::complex_double):
2479-
{
2480-
detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
2481-
std::complex<double>, std::complex<double>>(
2482-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2483-
batch_size);
2436+
detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2437+
beta, c, ldc, batch_size, matrix_info);
24842438
break;
24852439
}
24862440
case detail::get_type_combination_id(
24872441
library_data_t::real_half, library_data_t::real_half,
24882442
library_data_t::real_half, library_data_t::real_half):
24892443
{
2490-
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
2491-
sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
2492-
a, lda, b, ldb, beta, c, ldc,
2493-
batch_size);
2444+
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2445+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
24942446
break;
24952447
}
24962448
#ifdef __INTEL_MKL__
24972449
case detail::get_type_combination_id(
24982450
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
24992451
library_data_t::real_bfloat16, library_data_t::real_float):
25002452
{
2501-
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
2502-
oneapi::mkl::bfloat16, float>(
2503-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2504-
batch_size);
2453+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2454+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25052455
break;
25062456
}
25072457
case detail::get_type_combination_id(
25082458
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
25092459
library_data_t::real_float, library_data_t::real_float):
25102460
{
2511-
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
2512-
float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
2513-
b, ldb, beta, c, ldc, batch_size);
2461+
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2462+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25142463
break;
25152464
}
25162465
#endif
@@ -2522,28 +2471,25 @@ namespace dpct
25222471
dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
25232472
float beta_float =
25242473
dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
2525-
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
2526-
float>(q, a_trans, b_trans, m, n, k, &alpha_float,
2527-
a, lda, b, ldb, &beta_float, c, ldc,
2528-
batch_size);
2474+
detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
2475+
q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
2476+
matrix_info);
25292477
break;
25302478
}
25312479
case detail::get_type_combination_id(
25322480
library_data_t::real_int8, library_data_t::real_int8,
25332481
library_data_t::real_float, library_data_t::real_float):
25342482
{
25352483
detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
2536-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2537-
batch_size);
2484+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25382485
break;
25392486
}
25402487
case detail::get_type_combination_id(
25412488
library_data_t::real_half, library_data_t::real_half,
25422489
library_data_t::real_float, library_data_t::real_float):
25432490
{
25442491
detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
2545-
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
2546-
batch_size);
2492+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
25472493
break;
25482494
}
25492495
case detail::get_type_combination_id(
@@ -2557,8 +2503,7 @@ namespace dpct
25572503
sycl::half alpha_half(alpha_value);
25582504
sycl::half beta_half(beta_value);
25592505
detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2560-
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
2561-
batch_size);
2506+
q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
25622507
break;
25632508
}
25642509
default:

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 84 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
11731173
}
11741174
};
11751175

1176+
struct ggml_sycl_pool_host : public ggml_sycl_pool {
1177+
queue_ptr qptr;
1178+
int device;
1179+
1180+
inline static int counter{ 0 };
1181+
1182+
struct ggml_sycl_buffer {
1183+
void * ptr = nullptr;
1184+
size_t size = 0;
1185+
};
1186+
1187+
// Set arbitrarily to 64
1188+
static constexpr int MAX_POOL_SIZE{ 64 };
1189+
std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
1190+
size_t pool_size = 0;
1191+
1192+
explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
1193+
1194+
~ggml_sycl_pool_host() {
1195+
for (int i = 0; i < MAX_POOL_SIZE; ++i) {
1196+
ggml_sycl_buffer & b = buffer_pool[i];
1197+
if (b.ptr != nullptr) {
1198+
SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
1199+
b.ptr = nullptr;
1200+
pool_size -= b.size;
1201+
b.size = 0;
1202+
}
1203+
}
1204+
counter = 0;
1205+
}
1206+
1207+
void * alloc(size_t size, size_t * actual_size) override {
1208+
if (counter == MAX_POOL_SIZE) {
1209+
ggml_sycl_buffer b = buffer_pool[0];
1210+
void * ptr = b.ptr;
1211+
*actual_size = b.size;
1212+
counter = 1;
1213+
return ptr;
1214+
}
1215+
ggml_sycl_buffer & b = buffer_pool[counter];
1216+
1217+
if (b.ptr == nullptr) {
1218+
void * ptr;
1219+
1220+
SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
1221+
if (!ptr) {
1222+
GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
1223+
return nullptr;
1224+
}
1225+
pool_size += size;
1226+
*actual_size = size;
1227+
counter = counter + 1;
1228+
return ptr;
1229+
} else {
1230+
++counter;
1231+
b.size = size;
1232+
return b.ptr;
1233+
}
1234+
}
1235+
1236+
void free(void * ptr, size_t size) override {
1237+
// If the pool is not yet full, store the pointer in the first empty (nullptr) slot found.
1238+
// Otherwise do nothing, pointers will be freed once the pool is deallocated.
1239+
for (int i = 0; i < MAX_POOL_SIZE; ++i) {
1240+
ggml_sycl_buffer & b = buffer_pool[i];
1241+
if (b.ptr == nullptr) {
1242+
b.ptr = ptr;
1243+
b.size = size;
1244+
return;
1245+
}
1246+
}
1247+
}
1248+
};
1249+
1250+
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
1251+
// return pool for the host to speed up memory management
1252+
return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device));
1253+
}
1254+
11761255
std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
11771256
// TBD: NO VMM support
11781257
// if (ggml_sycl_info().devices[device].vmm) {
@@ -3363,6 +3442,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
33633442

33643443
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
33653444
ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
3445+
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
33663446

33673447
sycl::range<3> block_dims(1, ne12, ne13);
33683448
/*
@@ -3391,14 +3471,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
33913471
});
33923472
}
33933473
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
3394-
*main_stream, oneapi::mkl::transpose::trans,
3395-
oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3396-
(const void **)(ptrs_src.get() + 0 * ne23),
3397-
dpct::library_data_t::real_half, nb01 / nb00,
3398-
(const void **)(ptrs_src.get() + 1 * ne23),
3399-
dpct::library_data_t::real_half, nb11 / nb10, beta,
3400-
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
3401-
cu_compute_type)));
3474+
*main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
3475+
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
3476+
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
3477+
(void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
34023478
}
34033479
}
34043480
catch (sycl::exception const &exc) {

0 commit comments

Comments
 (0)