
Commit 4e3db77 (1 parent: 917dc8c)

fix mul_mat_id to match the change of api

File tree: 1 file changed (+256 −54 lines)


ggml-sycl.cpp

Lines changed: 256 additions & 54 deletions
@@ -2944,6 +2944,57 @@ namespace dpct
     using shared_memory = detail::device_memory<T, shared, Dimension>;
 
 
+    template <typename T,
+              sycl::access::address_space addressSpace =
+                  sycl::access::address_space::global_space,
+              sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+              sycl::memory_scope memoryScope = sycl::memory_scope::device>
+    inline T atomic_fetch_add(T *addr, T operand) {
+        auto atm =
+            sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
+        return atm.fetch_add(operand);
+    }
+
+    template <sycl::access::address_space addressSpace =
+                  sycl::access::address_space::global_space,
+              sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+              sycl::memory_scope memoryScope = sycl::memory_scope::device,
+              typename T1, typename T2>
+    inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
+        auto atm =
+            sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
+        return atm.fetch_add(operand);
+    }
+
+    template <typename T, sycl::access::address_space addressSpace =
+                              sycl::access::address_space::global_space>
+    inline T atomic_fetch_add(T *addr, T operand,
+                              sycl::memory_order memoryOrder) {
+        switch (memoryOrder) {
+        case sycl::memory_order::relaxed:
+            return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
+                                    sycl::memory_scope::device>(addr, operand);
+        case sycl::memory_order::acq_rel:
+            return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
+                                    sycl::memory_scope::device>(addr, operand);
+        case sycl::memory_order::seq_cst:
+            return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
+                                    sycl::memory_scope::device>(addr, operand);
+        default:
+            assert(false && "Invalid memory_order for atomics. Valid memory_order for "
+                            "atomics are: sycl::memory_order::relaxed, "
+                            "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
+        }
+    }
+
+    template <sycl::access::address_space addressSpace =
+                  sycl::access::address_space::global_space,
+              typename T1, typename T2>
+    inline T1 atomic_fetch_add(T1 *addr, T2 operand,
+                               sycl::memory_order memoryOrder) {
+        return atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
+    }
+
 } // COPY from DPCT head files
 
 #define GGML_COMMON_DECL_SYCL
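The atomic_fetch_add overloads above are copied from the DPCT helper headers and wrap sycl::atomic_ref; the kernel added further down calls the generic_space variant to claim row slots. A minimal standalone sketch of that usage, assuming a USM-capable queue (the setup names here are illustrative, not part of the commit):

    // Sketch: what dpct::atomic_fetch_add<generic_space>(counter, 1) boils
    // down to. Assumes USM shared allocations are available on the device.
    #include <sycl/sycl.hpp>
    #include <cassert>

    int main() {
        sycl::queue q;
        int *counter = sycl::malloc_shared<int>(1, q);
        *counter = 0;
        q.parallel_for(sycl::range<1>(64), [=](sycl::id<1>) {
            sycl::atomic_ref<int, sycl::memory_order::relaxed,
                             sycl::memory_scope::device,
                             sycl::access::address_space::generic_space>
                atm(*counter);
            atm.fetch_add(1); // each work-item claims a unique slot
        }).wait();
        assert(*counter == 64); // 64 work-items, one increment each
        sycl::free(counter, q);
        return 0;
    }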
@@ -3060,6 +3111,7 @@ void ggml_sycl_get_device_description(int device, char * description, size_t d
 bool ggml_backend_is_sycl(ggml_backend_t backend);
 int ggml_backend_sycl_get_device(ggml_backend_t backend);
 int get_main_device();
+static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
 void print_ggml_tensor(const char*name, struct ggml_tensor *src);
 void log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt);
 
@@ -15899,22 +15951,86 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
 }
 #endif
 
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+__dpct_inline__ static void k_copy_src1_to_contiguous(
+    const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
+    int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
+    const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+    int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
+    const sycl::nd_item<3> &item_ct1, int &src1_row) {
+    int32_t iid1 = item_ct1.get_group(2);
+    int32_t id = item_ct1.get_group(1);
+
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+    if (row_id_i != i02) {
+        return;
+    }
+
+    const int64_t i11 = id % ne11;
+    const int64_t i12 = iid1;
+
+    if (item_ct1.get_local_id(2) == 0) {
+        src1_row =
+            dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
+                cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
+    }
+    /*
+    DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
+    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
+    performance if there is no access to global memory.
+    */
+    item_ct1.barrier();
+
+    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+#pragma unroll
+    for (int i = item_ct1.get_local_id(2); i < ne10;
+         i += item_ct1.get_local_range(2)) {
+        src1_row_contiguous[i] = src1_row_original[i];
+    }
+}
+
+__dpct_inline__ static void k_copy_dst_from_contiguous(
+    char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
+    const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
+    size_t nb2, const sycl::nd_item<3> &item_ct1) {
+    int32_t i = item_ct1.get_group(2);
+
+    const int32_t i1 = row_mapping[i].i1;
+    const int32_t i2 = row_mapping[i].i2;
+
+    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+#pragma unroll
+    for (int j = item_ct1.get_local_id(2); j < ne0;
+         j += item_ct1.get_local_range(2)) {
+        dst_row_original[j] = dst_row_contiguous[j];
+    }
+}
+
 static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
                                  const ggml_tensor *src1,
                                  ggml_tensor *dst) try {
     GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
                 "mul_mat_id does not support split buffers");
     const ggml_tensor *ids = dst->src[2];
+    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
     const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
 
-    const size_t nb11 = src1->nb[1];
-    const size_t nb1 = dst->nb[1];
-
-    const int32_t id = ((int32_t *)dst->op_params)[0];
-    const int32_t n_as = src0->ne[2];
+    const int64_t n_as = ne02;
+    const int64_t n_ids = ids->ne[0];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
-    const char *ids_dev = (const char *)ids->data;
+    const char * ids_dev = (const char *) ids->data;
 
     SYCL_CHECK(CHECK_TRY_ERROR(
         stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
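For orientation: k_copy_src1_to_contiguous launches one work-group per (token, expert-slot) pair; groups whose ids entry matches the current expert i02 atomically claim the next slot in the contiguous buffer and record the row's origin in mmid_row_mapping, which k_copy_dst_from_contiguous later uses to scatter results back. A sequential CPU sketch of the same gather bookkeeping, assuming densely packed float rows (all names except mmid_row_mapping are hypothetical):

    // Hedged CPU reference for the gather step; the real kernel does the same
    // per work-group, with cur incremented via dpct::atomic_fetch_add.
    #include <cstdint>
    #include <cstring>
    #include <vector>

    struct mmid_row_mapping { int32_t i1; int32_t i2; };

    static int gather_rows_for_expert(
            const int32_t *ids, int64_t n_tokens, int64_t n_ids, int64_t i02,
            const float *src1, int64_t ne10, int64_t ne11,
            float *contig, std::vector<mmid_row_mapping> &map) {
        int cur = 0; // role of the atomically incremented cur_src1_row
        for (int64_t iid1 = 0; iid1 < n_tokens; iid1++) {
            for (int64_t id = 0; id < n_ids; id++) {
                if (ids[iid1*n_ids + id] != i02) continue; // row not routed here
                const int64_t i11 = id % ne11, i12 = iid1;
                std::memcpy(contig + cur*ne10,
                            src1 + (i12*ne11 + i11)*ne10, ne10*sizeof(float));
                map.push_back({(int32_t)id, (int32_t)iid1}); // remember origin
                cur++;
            }
        }
        return cur; // equals num_src1_rows for this expert
    }

The scatter step then walks map in order: row r of the contiguous dst buffer goes back to position (map[r].i1, map[r].i2) of the original dst tensor.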
@@ -15954,24 +16070,40 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
 
     src0_row.ne[2] = 1;
     src0_row.ne[3] = 1;
-    src0_row.nb[3] = src0->nb[2];
-
-    if (src1->ne[1] == 1) {
-        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            const int32_t row_id =
-                *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
-                                   id * ids->nb[0]);
-
-            GGML_ASSERT(row_id >= 0 && row_id < n_as);
+    src0_row.nb[3] = nb02;
+
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
+
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
 
             src0_row_extra.data_device[g_main_device] =
-                src0_original + row_id * src0->nb[2];
+                    src0_original + i02*nb02;
             src1_row_extra.data_device[g_main_device] =
-                src1_original + i01 * src1->nb[1];
+                    src1_original + i11*nb11 + i12*nb12;
             dst_row_extra.data_device[g_main_device] =
-                dst_original + i01 * dst->nb[1];
+                    dst_original + i1*nb1 + i2*nb2;
 
             ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
+            }
         }
     } else {
         sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
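In the ne12 == 1 fast path above, no data is copied at all: the preconfigured one-row views src0_row/src1_row/dst_row are simply retargeted per selected expert. The index arithmetic is the part that changed with the new ids API. A sketch of the mapping from one ids entry to the three byte offsets, assuming ggml's convention that nbX is the byte stride of dimension X (the helper itself is hypothetical, not part of the commit):

    // Offsets for one (iid1, id) entry of ids, per the loop above.
    #include <cstddef>
    #include <cstdint>

    struct row_offsets { size_t src0, src1, dst; };

    static row_offsets mul_mat_id_offsets(int64_t id, int64_t iid1, int32_t i02,
                                          int64_t ne11,
                                          size_t nb02,              // src0: expert stride
                                          size_t nb11, size_t nb12, // src1: row, batch
                                          size_t nb1, size_t nb2) { // dst: row, batch
        const int64_t i11 = id % ne11; // src1 row within the batch
        const int64_t i12 = iid1;      // batch index (one per token)
        const int64_t i1  = id;        // dst row
        const int64_t i2  = i12;       // dst batch
        return { (size_t)i02 * nb02,
                 (size_t)i11 * nb11 + (size_t)i12 * nb12,
                 (size_t)i1  * nb1  + (size_t)i2  * nb2 };
    }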
@@ -15980,63 +16112,134 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
         src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
         dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
 
-        for (int32_t row_id = 0; row_id < n_as; ++row_id) {
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-                if (row_id_i != row_id) {
-                    continue;
-                }
+                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
 
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
+                    if (row_id_i != i02) {
+                        continue;
+                    }
 
-                SYCL_CHECK(CHECK_TRY_ERROR(
-                    stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
-                                   src1_original + i01 * nb11, nb11)));
-                num_src1_rows++;
+                    num_src1_rows++;
+                }
             }
 
             if (num_src1_rows == 0) {
                 continue;
             }
 
-            src0_row_extra.data_device[g_main_device] =
-                src0_original + row_id * src0->nb[2];
 
+            sycl_pool_alloc<int> dev_cur_src1_row(1);
+            sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(num_src1_rows);
+            SYCL_CHECK(CHECK_TRY_ERROR(
+                stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
+
+            {
+                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
+                sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
+                /*
+                DPCT1049:81: The work-group size passed to the SYCL kernel may
+                exceed the limit. To get the device limit, query
+                info::device::max_work_group_size. Adjust the work-group size if
+                needed.
+                */
+                stream->submit([&](sycl::handler &cgh) {
+                    sycl::local_accessor<int, 0> src1_row_acc_ct1(cgh);
+
+                    char *__restrict src1_contiguous_get_ct1 =
+                        src1_contiguous.get();
+                    int *__restrict dev_cur_src1_row_get_ct2 =
+                        dev_cur_src1_row.get();
+                    mmid_row_mapping *__restrict dev_row_mapping_get_ct3 =
+                        dev_row_mapping.get();
+                    size_t ids_nb_ct6 = ids->nb[1];
+                    size_t ids_nb_ct7 = ids->nb[0];
+
+                    cgh.parallel_for(
+                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_copy_src1_to_contiguous(
+                                src1_original, src1_contiguous_get_ct1,
+                                dev_cur_src1_row_get_ct2,
+                                dev_row_mapping_get_ct3, ids_dev, i02,
+                                ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
+                                item_ct1, src1_row_acc_ct1);
+                        });
+                });
+                /*
+                DPCT1010:196: SYCL uses exceptions to report errors and does not
+                use the error codes. The call was replaced with 0. You need to
+                rewrite this code.
+                */
+                /*
+                DPCT1009:197: SYCL uses exceptions to report errors and does not
+                use the error codes. The call was replaced by a placeholder
+                string. You need to rewrite this code.
+                */
+                SYCL_CHECK(0);
+            }
+
+            src0_row_extra.data_device[g_main_device] = src0_original + i02*nb02;
+
+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1 == sizeof(float)*ne0);
             src1_row.ne[1] = num_src1_rows;
-            dst_row.ne[1] = num_src1_rows;
 
             src1_row.nb[1] = nb11;
             src1_row.nb[2] = num_src1_rows*nb11;
             src1_row.nb[3] = num_src1_rows*nb11;
 
+            dst_row.ne[1] = num_src1_rows;
             dst_row.nb[1] = nb1;
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
             ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
 
-            num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
-
-                if (row_id_i != row_id) {
-                    continue;
-                }
-
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
-
-                SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
-                    dst_original + i01 * nb1,
-                    dst_contiguous.get() + num_src1_rows * nb1, nb1)));
-                num_src1_rows++;
+            {
+                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
+                sycl::range<3> grid_dims(1, 1, num_src1_rows);
+                /*
+                DPCT1049:82: The work-group size passed to the SYCL kernel may
+                exceed the limit. To get the device limit, query
+                info::device::max_work_group_size. Adjust the work-group size if
+                needed.
+                */
+                stream->submit([&](sycl::handler &cgh) {
+                    const char *__restrict dst_contiguous_get_ct1 =
+                        dst_contiguous.get();
+                    const mmid_row_mapping *__restrict dev_row_mapping_get_ct2 =
+                        dev_row_mapping.get();
+
+                    cgh.parallel_for(
+                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_copy_dst_from_contiguous(dst_original,
+                                                       dst_contiguous_get_ct1,
+                                                       dev_row_mapping_get_ct2,
+                                                       ne0, nb1, nb2, item_ct1);
+                        });
+                });
+                /*
+                DPCT1010:198: SYCL uses exceptions to report errors and does not
+                use the error codes. The call was replaced with 0. You need to
+                rewrite this code.
+                */
+                /*
+                DPCT1009:199: SYCL uses exceptions to report errors and does not
+                use the error codes. The call was replaced by a placeholder
+                string. You need to rewrite this code.
+                */
+                SYCL_CHECK(0);
+            }
         }
     }
-    }
-
-    if (dst->backend == GGML_BACKEND_TYPE_CPU) {
-        SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
+        if (dst->backend == GGML_BACKEND_TYPE_CPU) {
+            SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
+        }
     }
 }
 catch (sycl::exception const &exc) {
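Both launches in this hunk build their iteration space the dpct way: grid_dims counts work-groups, block_dims is the work-group size (capped at 768 work-items), and sycl::nd_range takes the global range, so the two are multiplied element-wise. A small sketch of that composition for the copy-back launch, with illustrative sizes (the ne0 and num_src1_rows values are examples, not from the commit):

    #include <sycl/sycl.hpp>
    #include <algorithm>

    int main() {
        const unsigned int ne0 = 4096;  // example dst row length
        const size_t num_src1_rows = 7; // example compacted row count
        sycl::range<3> block_dims(1, 1, std::min(ne0, 768u)); // work-group size
        sycl::range<3> grid_dims(1, 1, num_src1_rows);        // work-group count
        // global range = grid_dims * block_dims: one work-group per row,
        // each work-item striding over that row's ne0 elements
        sycl::nd_range<3> ndr(grid_dims * block_dims, block_dims);
        return ndr.get_global_range()[2] == num_src1_rows * 768 ? 0 : 1;
    }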
@@ -17020,10 +17223,9 @@ GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backe
     UNUSED(buffer);
 }
 
-// unused at the moment
-//static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
-//    return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
-//}
+static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
+    return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
+}
 
 GGML_CALL static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
