Skip to content

Commit 8ef8fe1

Browse files
committed
remove duplicate buft initialization
1 parent 40520b1 commit 8ef8fe1

File tree

1 file changed

+16
-33
lines changed

1 file changed

+16
-33
lines changed

ggml-sycl.cpp

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12675,6 +12675,9 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
1267512675
};
1267612676

1267712677
ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
12678+
static std::mutex mutex;
12679+
std::lock_guard<std::mutex> lock(mutex);
12680+
1267812681
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
1267912682

1268012683
if (device>=ggml_sycl_info().device_count or device<0) {
@@ -12700,31 +12703,6 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
1270012703
return &ggml_backend_sycl_buffer_types[device];
1270112704
}
1270212705

12703-
ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
12704-
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
12705-
12706-
int device = ctx->device;
12707-
if (device>=ggml_sycl_info().device_count or device<0) {
12708-
printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
12709-
device, ggml_sycl_info().device_count-1);
12710-
GGML_ASSERT(device<ggml_sycl_info().device_count);
12711-
}
12712-
static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
12713-
12714-
static bool ggml_backend_sycl_buffer_type_initialized = false;
12715-
12716-
if (!ggml_backend_sycl_buffer_type_initialized) {
12717-
for (int i = 0; i < ggml_sycl_info().device_count; i++) {
12718-
ggml_backend_sycl_buffer_types[i] = {
12719-
/* .iface = */ ggml_backend_sycl_buffer_type_interface,
12720-
/* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), ctx->stream(i, 0)},
12721-
};
12722-
}
12723-
ggml_backend_sycl_buffer_type_initialized = true;
12724-
}
12725-
return &ggml_backend_sycl_buffer_types[device];
12726-
}
12727-
1272812706
// sycl split buffer type
1272912707
static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split, int id) {
1273012708
const int64_t nrows = ggml_nrows(tensor);
@@ -13076,6 +13054,9 @@ static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface
1307613054
};
1307713055

1307813056
GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
13057+
static std::mutex mutex;
13058+
std::lock_guard<std::mutex> lock(mutex);
13059+
1307913060
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n");
1308013061
ggml_check_sycl();
1308113062
// FIXME: this is not thread safe
@@ -13183,16 +13164,17 @@ GGML_CALL static void ggml_backend_sycl_free(ggml_backend_t backend) {
1318313164

1318413165
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_sycl_get_default_buffer_type(ggml_backend_t backend) {
1318513166
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
13186-
return ggml_backend_sycl_buffer_type(sycl_ctx);
13167+
return ggml_backend_sycl_buffer_type(sycl_ctx->device);
1318713168
}
1318813169

1318913170
GGML_CALL static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
1319013171
ggml_tensor *tensor,
1319113172
const void *data, size_t offset,
1319213173
size_t size) try {
1319313174
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
13194-
GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && "unsupported buffer type");
13195-
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
13175+
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
13176+
13177+
GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
1319613178
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
1319713179
SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
1319813180
(char *)tensor->data + offset, data, size).wait()));
@@ -13208,8 +13190,9 @@ GGML_CALL static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
1320813190
void *data, size_t offset,
1320913191
size_t size) try {
1321013192
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
13211-
GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && "unsupported buffer type");
13212-
GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU);
13193+
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
13194+
13195+
GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
1321313196
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
1321413197
SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
1321513198
data, (const char *)tensor->data + offset, size).wait()));
@@ -13224,7 +13207,7 @@ GGML_CALL static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
1322413207
const ggml_tensor *src,
1322513208
ggml_tensor *dst) try {
1322613209
ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
13227-
if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx) && ggml_backend_buffer_is_sycl(src->buffer)) {
13210+
if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && ggml_backend_buffer_is_sycl(src->buffer)) {
1322813211
/*
1322913212
DPCT1009:215: SYCL uses exceptions to report errors and does not use the
1323013213
error codes. The original code was commented out and a warning string
@@ -13268,10 +13251,10 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back
1326813251
continue;
1326913252
}
1327013253
#ifndef NDEBUG
13271-
assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx));
13254+
assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
1327213255
for (int j = 0; j < GGML_MAX_SRC; j++) {
1327313256
if (node->src[j] != nullptr) {
13274-
assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx));
13257+
assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
1327513258
}
1327613259
}
1327713260
#endif

0 commit comments

Comments
 (0)