Skip to content

Commit 19af285

Browse files
committed
remove duplicate buft initialization
1 parent 0915761 commit 19af285

File tree

1 file changed

+16
-33
lines changed

1 file changed

+16
-33
lines changed

ggml-sycl.cpp

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12615,6 +12615,9 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
1261512615
};
1261612616

1261712617
ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
12618+
static std::mutex mutex;
12619+
std::lock_guard<std::mutex> lock(mutex);
12620+
1261812621
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
1261912622

1262012623
if (device>=ggml_sycl_info().device_count or device<0) {
@@ -12640,31 +12643,6 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
1264012643
return &ggml_backend_sycl_buffer_types[device];
1264112644
}
1264212645

12643-
ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
12644-
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
12645-
12646-
int device = ctx->device;
12647-
if (device>=ggml_sycl_info().device_count or device<0) {
12648-
printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
12649-
device, ggml_sycl_info().device_count-1);
12650-
GGML_ASSERT(device<ggml_sycl_info().device_count);
12651-
}
12652-
static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
12653-
12654-
static bool ggml_backend_sycl_buffer_type_initialized = false;
12655-
12656-
if (!ggml_backend_sycl_buffer_type_initialized) {
12657-
for (int i = 0; i < ggml_sycl_info().device_count; i++) {
12658-
ggml_backend_sycl_buffer_types[i] = {
12659-
/* .iface = */ ggml_backend_sycl_buffer_type_interface,
12660-
/* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), ctx->stream(i, 0)},
12661-
};
12662-
}
12663-
ggml_backend_sycl_buffer_type_initialized = true;
12664-
}
12665-
return &ggml_backend_sycl_buffer_types[device];
12666-
}
12667-
1266812646
// sycl split buffer type
1266912647
static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split, int id) {
1267012648
const int64_t nrows = ggml_nrows(tensor);
@@ -13016,6 +12994,9 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface
1301612994
};
1301712995

1301812996
GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
12997+
static std::mutex mutex;
12998+
std::lock_guard<std::mutex> lock(mutex);
12999+
1301913000
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n");
1302013001
ggml_check_sycl();
1302113002
// FIXME: this is not thread safe
@@ -13123,16 +13104,17 @@ GGML_CALL static void ggml_backend_sycl_free(ggml_backend_t backend) {
1312313104

1312413105
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_sycl_get_default_buffer_type(ggml_backend_t backend) {
1312513106
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
13126-
return ggml_backend_sycl_buffer_type(sycl_ctx);
13107+
return ggml_backend_sycl_buffer_type(sycl_ctx->device);
1312713108
}
1312813109

1312913110
GGML_CALL static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
1313013111
ggml_tensor *tensor,
1313113112
const void *data, size_t offset,
1313213113
size_t size) try {
1313313114
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
13134-
GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && "unsupported buffer type");
13135-
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
13115+
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
13116+
13117+
GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
1313613118
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
1313713119
SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
1313813120
(char *)tensor->data + offset, data, size).wait()));
@@ -13148,8 +13130,9 @@ GGML_CALL static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
1314813130
void *data, size_t offset,
1314913131
size_t size) try {
1315013132
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
13151-
GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && "unsupported buffer type");
13152-
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
13133+
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
13134+
13135+
GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
1315313136
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
1315413137
SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
1315513138
data, (const char *)tensor->data + offset, size).wait()));
@@ -13164,7 +13147,7 @@ GGML_CALL static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
1316413147
const ggml_tensor *src,
1316513148
ggml_tensor *dst) try {
1316613149
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
13167-
if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && ggml_backend_buffer_is_sycl(src->buffer)) {
13150+
if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && ggml_backend_buffer_is_sycl(src->buffer)) {
1316813151
/*
1316913152
DPCT1009:215: SYCL uses exceptions to report errors and does not use the
1317013153
error codes. The original code was commented out and a warning string
@@ -13208,10 +13191,10 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back
1320813191
continue;
1320913192
}
1321013193
#ifndef NDEBUG
13211-
assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx));
13194+
assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
1321213195
for (int j = 0; j < GGML_MAX_SRC; j++) {
1321313196
if (node->src[j] != nullptr) {
13214-
assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx));
13197+
assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
1321513198
}
1321613199
}
1321713200
#endif

0 commit comments

Comments
 (0)