fix mul_mat_id() for new input, make the ut pass #6682

Merged (1 commit, merged Apr 15, 2024)
96 changes: 50 additions & 46 deletions ggml-sycl.cpp
@@ -15996,73 +15996,76 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
 static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
                                  const ggml_tensor *src1,
                                  ggml_tensor *dst) try {
-#if 0
-    ggml_sycl_mul_mat_id_sycl(dst);
-    // TODO: mmq/mmv support
-#endif
+    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
+                "mul_mat_id does not support split buffers");
+    const ggml_tensor *ids = dst->src[2];
+    const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
 
-    const int64_t nb11 = src1->nb[1];
-    const int64_t nb1 = dst->nb[1];
+    const size_t nb11 = src1->nb[1];
+    const size_t nb1 = dst->nb[1];
 
-    const struct ggml_tensor * ids = src0;
-    const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = ((int32_t *) dst->op_params)[1];
+    const int32_t id = ((int32_t *)dst->op_params)[0];
+    const int32_t n_as = src0->ne[2];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
+    const char *ids_dev = (const char *)ids->data;
 
-    const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
-
-    if (ids->backend == GGML_BACKEND_TYPE_GPU) {
-        const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)).wait()));
-        // SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
-    } else {
-        memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
-    }
+    SYCL_CHECK(CHECK_TRY_ERROR(
+        stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
+    SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
 
-    const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
-    const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
+    const ggml_tensor_extra_gpu *src0_extra =
+        (const ggml_tensor_extra_gpu *)src0->extra;
+    const ggml_tensor_extra_gpu *src1_extra =
+        (const ggml_tensor_extra_gpu *)src1->extra;
+    const ggml_tensor_extra_gpu *dst_extra =
+        (const ggml_tensor_extra_gpu *)dst->extra;
 
+    ggml_tensor_extra_gpu src0_row_extra;
     ggml_tensor_extra_gpu src1_row_extra;
     ggml_tensor_extra_gpu dst_row_extra;
 
+    ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row = *dst;
 
     src1_row.backend = GGML_BACKEND_TYPE_GPU;
     dst_row.backend = GGML_BACKEND_TYPE_GPU;
 
+    src0_row.extra = &src0_row_extra;
     src1_row.extra = &src1_row_extra;
     dst_row.extra = &dst_row_extra;
 
-    char * src1_original = src1->backend == GGML_BACKEND_TYPE_CPU ?
-        (char *) src1->data : (char *) src1_extra->data_device[g_main_device];
-    char * dst_original = dst->backend == GGML_BACKEND_TYPE_CPU ?
-        (char *) dst->data : (char *) dst_extra->data_device[g_main_device];
+    char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU
+                              ? (char *)src0->data
+                              : (char *)src0_extra->data_device[g_main_device];
+    char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU
+                              ? (char *)src1->data
+                              : (char *)src1_extra->data_device[g_main_device];
+    char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU
+                             ? (char *)dst->data
+                             : (char *)dst_extra->data_device[g_main_device];
 
-    if (src1->ne[1] == 1) {
-        GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
-        GGML_ASSERT(dst->backend == GGML_BACKEND_TYPE_GPU);
+    src0_row.ne[2] = 1;
+    src0_row.ne[3] = 1;
+    src0_row.nb[3] = src0->nb[2];
 
+    if (src1->ne[1] == 1) {
         for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            //int32_t row_id;
-            //SYCL_CHECK(syclMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), syclMemcpyDeviceToHost, g_syclStreams[g_main_device][0]));
-            //SYCL_CHECK(syclStreamSynchronize(g_syclStreams[g_main_device][0]));
-
-            const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+            const int32_t row_id =
+                *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
+                                   id * ids->nb[0]);
 
             GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
-            const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
-            src1_row_extra.data_device[g_main_device] = src1_original + i01*src1->nb[1];
-            src1_row.data = (char *) src1->data + i01*src1->nb[1]; // TODO why is this set?
-
-            dst_row_extra.data_device[g_main_device] = dst_original + i01*dst->nb[1];
-            dst_row.data = (char *) dst->data + i01*dst->nb[1]; // TODO why is this set?
+            src0_row_extra.data_device[g_main_device] =
+                src0_original + row_id * src0->nb[2];
+            src1_row_extra.data_device[g_main_device] =
+                src1_original + i01 * src1->nb[1];
+            dst_row_extra.data_device[g_main_device] =
+                dst_original + i01 * dst->nb[1];
 
-            ggml_sycl_mul_mat(src0_row, &src1_row, &dst_row);
+            ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
         }
     } else {
         sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
@@ -16072,8 +16075,6 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
         dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
 
         for (int32_t row_id = 0; row_id < n_as; ++row_id) {
-            const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
             int64_t num_src1_rows = 0;
             for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
                 const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
@@ -16086,14 +16087,17 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,

                 SYCL_CHECK(CHECK_TRY_ERROR(
                     stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
-                                   src1_original + i01 * nb11, nb11).wait()));
+                                   src1_original + i01 * nb11, nb11)));
                 num_src1_rows++;
             }
 
             if (num_src1_rows == 0) {
                 continue;
             }
 
+            src0_row_extra.data_device[g_main_device] =
+                src0_original + row_id * src0->nb[2];
+
             src1_row.ne[1] = num_src1_rows;
             dst_row.ne[1] = num_src1_rows;
 
@@ -16105,7 +16109,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
-            ggml_sycl_mul_mat(src0_row, &src1_row, &dst_row);
+            ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
 
             num_src1_rows = 0;
             for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
@@ -16119,7 +16123,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,

                 SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
                     dst_original + i01 * nb1,
-                    dst_contiguous.get() + num_src1_rows * nb1, nb1).wait()));
+                    dst_contiguous.get() + num_src1_rows * nb1, nb1)));
                 num_src1_rows++;
             }
         }
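Context for the diff, as I read it: before this patch, ggml_sycl_mul_mat_id() treated src0 as the routing ids tensor and looked up each expert as a separate tensor via dst->src[row_id + 2]. The new input layout fuses all n_as expert matrices into src0 along dimension 2, with the ids supplied as dst->src[2], so selecting an expert reduces to offsetting a one-matrix view (src0_row) into the fused buffer by row_id * src0->nb[2]. Below is a minimal standalone sketch of that addressing scheme; the dimensions are hypothetical and plain host arrays stand in for the SYCL device buffers, so it illustrates only the stride arithmetic, not the actual ggml-sycl API.

    // Illustrative only: hypothetical sizes, host memory instead of SYCL buffers.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        const int64_t ne0  = 4; // columns per expert matrix
        const int64_t ne1  = 3; // rows per expert matrix
        const int64_t n_as = 2; // experts stacked along dimension 2 of src0

        // Fused expert tensor: n_as contiguous ne1 x ne0 matrices.
        std::vector<float> src0((size_t)(n_as * ne1 * ne0));
        for (size_t i = 0; i < src0.size(); ++i) src0[i] = (float)i;

        // Byte strides, mirroring ggml's nb[]: nb[1] = row stride,
        // nb[2] = stride between consecutive experts.
        const size_t nb1 = (size_t)ne0 * sizeof(float);
        const size_t nb2 = (size_t)ne1 * nb1;

        const char *src0_original = (const char *)src0.data();

        // Per-token routing ids, as mul_mat_id reads them from dst->src[2].
        const int32_t ids[] = {1, 0, 1};

        for (int64_t i01 = 0; i01 < 3; ++i01) {
            const int32_t row_id = ids[i01];
            // The view trick from the patch: expert row_id's matrix starts at
            // src0_original + row_id * nb[2]; no expert data is copied.
            const float *expert =
                (const float *)(src0_original + (size_t)row_id * nb2);
            std::printf("token %lld -> expert %d, first weight %.1f\n",
                        (long long)i01, row_id, expert[0]);
        }
        return 0;
    }

Setting src0_row.ne[2] = src0_row.ne[3] = 1 in the patch makes the view look like an ordinary 2-D weight matrix to ggml_sycl_mul_mat(); only src0_row_extra.data_device[g_main_device] changes per routed row, which is why both the single-row path and the batched path can reuse the same src0_row tensor.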