Skip to content

Commit 7369e54

Browse files
committed
Add back ggml_sycl_set_device to kernels
1 parent 0ae9a07 commit 7369e54

20 files changed

+48
-2
lines changed

ggml/src/ggml-sycl/argmax.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ static void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * d
5858
const int64_t nrows = ggml_nrows(dst->src[0]);
5959

6060
dpct::queue_ptr main_stream = ctx.stream();
61+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
6162
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
6263
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
6364
argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);

ggml/src/ggml-sycl/argsort.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
111111

112112
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
113113
dpct::queue_ptr main_stream = ctx.stream();
114+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
114115
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
115116
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
116117

ggml/src/ggml-sycl/binbcast.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
237237
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
238238
void * dst_dd = static_cast<void *>(dst->data);
239239
const dpct::queue_ptr main_stream = ctx.stream();
240+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
240241

241242
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
242243
main_stream);
@@ -250,6 +251,7 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
250251
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
251252
void * dst_dd = static_cast<void *>(dst->data);
252253
const dpct::queue_ptr main_stream = ctx.stream();
254+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
253255

254256
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
255257
main_stream);
@@ -263,6 +265,7 @@ inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
263265
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
264266
void * dst_dd = static_cast<void *>(dst->data);
265267
const dpct::queue_ptr main_stream = ctx.stream();
268+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
266269

267270
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
268271
main_stream);
@@ -276,6 +279,7 @@ inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
276279
const void * src1_dd = static_cast<void *>(dst->src[1]->data);
277280
void * dst_dd = static_cast<void *>(dst->data);
278281
const dpct::queue_ptr main_stream = ctx.stream();
282+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
279283

280284
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(dst->src[0], dst->src[1], dst, src0_dd, src1_dd, dst_dd,
281285
main_stream);
@@ -288,6 +292,7 @@ inline void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * d
288292
const void * src0_d = static_cast<void *>(dst->src[0]->data);
289293
void * dst_d = static_cast<void *>(dst->data);
290294
dpct::queue_ptr main_stream = ctx.stream();
295+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
291296

292297
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(dst, dst->src[0], dst, nullptr, src0_d, dst_d, main_stream);
293298
} catch (const sycl::exception & exc) {

ggml/src/ggml-sycl/clamp.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * ds
3030
memcpy(&min, dst->op_params, sizeof(float));
3131
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
3232
const dpct::queue_ptr main_stream = ctx.stream();
33+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
3334
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
3435
float * dst_dd = static_cast<float *>(dst->data);
3536

ggml/src/ggml-sycl/concat.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ static void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor * d
162162
const ggml_tensor *src0 = dst->src[0];
163163
const ggml_tensor *src1 = dst->src[1];
164164
queue_ptr stream = ctx.stream();
165+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
165166

166167
const int32_t dim = ((int32_t *)dst->op_params)[0];
167168

ggml/src/ggml-sycl/conv.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor
7979

8080
float * dst_d = (float *)dst->data;
8181
dpct::queue_ptr stream = ctx.stream();
82+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
8283

8384
GGML_ASSERT(src0->type == GGML_TYPE_F32);
8485
GGML_ASSERT( dst->type == GGML_TYPE_F32);

ggml/src/ggml-sycl/diagmask.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten
3737

3838
const int n_past = ((int32_t *) dst->op_params)[0];
3939
dpct::queue_ptr main_stream = ctx.stream();
40+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
4041
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
4142
float * dst_dd = static_cast<float *>(dst->data);
4243

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
514514
GGML_ASSERT(dst->type == GGML_TYPE_F32);
515515
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
516516
const dpct::queue_ptr main_stream = ctx.stream();
517+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
517518
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
518519
float * dst_dd = static_cast<float *>(dst->data);
519520

@@ -526,6 +527,7 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
526527
GGML_ASSERT(dst->type == GGML_TYPE_F32);
527528
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
528529
const dpct::queue_ptr main_stream = ctx.stream();
530+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
529531
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
530532
float * dst_dd = static_cast<float *>(dst->data);
531533

@@ -538,6 +540,7 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
538540
GGML_ASSERT(dst->type == GGML_TYPE_F32);
539541
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
540542
const dpct::queue_ptr main_stream = ctx.stream();
543+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
541544
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
542545
float * dst_dd = static_cast<float *>(dst->data);
543546

@@ -551,6 +554,7 @@ inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
551554
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
552555

553556
const dpct::queue_ptr main_stream = ctx.stream();
557+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
554558
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
555559
float * dst_dd = static_cast<float *>(dst->data);
556560
tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
@@ -562,6 +566,7 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
562566
GGML_ASSERT(dst->type == GGML_TYPE_F32);
563567
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
564568
const dpct::queue_ptr main_stream = ctx.stream();
569+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
565570
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
566571
float * dst_dd = static_cast<float *>(dst->data);
567572
relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
@@ -573,6 +578,7 @@ inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tenso
573578
GGML_ASSERT(dst->type == GGML_TYPE_F32);
574579
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
575580
const dpct::queue_ptr main_stream = ctx.stream();
581+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
576582
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
577583
float * dst_dd = static_cast<float *>(dst->data);
578584
hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
@@ -585,6 +591,7 @@ inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor
585591
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
586592

587593
const dpct::queue_ptr main_stream = ctx.stream();
594+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
588595
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
589596
float * dst_dd = static_cast<float *>(dst->data);
590597

@@ -597,6 +604,7 @@ inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
597604
GGML_ASSERT(dst->type == GGML_TYPE_F32);
598605
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
599606
const dpct::queue_ptr main_stream = ctx.stream();
607+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
600608
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
601609
float * dst_dd = static_cast<float *>(dst->data);
602610
exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
@@ -608,6 +616,7 @@ inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
608616
GGML_ASSERT( dst->type == GGML_TYPE_F32);
609617
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
610618
const dpct::queue_ptr main_stream = ctx.stream();
619+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
611620
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
612621
float * dst_dd = static_cast<float *>(dst->data);
613622

@@ -620,6 +629,7 @@ inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *d
620629
GGML_ASSERT(dst->type == GGML_TYPE_F32);
621630
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
622631
const dpct::queue_ptr main_stream = ctx.stream();
632+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
623633
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
624634
float * dst_dd = static_cast<float *>(dst->data);
625635

@@ -632,6 +642,7 @@ inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst
632642
GGML_ASSERT(dst->type == GGML_TYPE_F32);
633643
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
634644
const dpct::queue_ptr main_stream = ctx.stream();
645+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
635646
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
636647
float * dst_dd = static_cast<float *>(dst->data);
637648
sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
@@ -643,6 +654,7 @@ inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
643654
GGML_ASSERT(dst->type == GGML_TYPE_F32);
644655
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
645656
const dpct::queue_ptr main_stream = ctx.stream();
657+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
646658
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
647659
float * dst_dd = static_cast<float *>(dst->data);
648660

@@ -655,6 +667,7 @@ inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
655667
GGML_ASSERT(dst->type == GGML_TYPE_F32);
656668
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
657669
const dpct::queue_ptr main_stream = ctx.stream();
670+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
658671
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
659672
float * dst_dd = static_cast<float *>(dst->data);
660673

@@ -669,6 +682,7 @@ inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
669682
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
670683
float * dst_dd = static_cast<float *>(dst->data);
671684
dpct::queue_ptr main_stream = ctx.stream();
685+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
672686

673687
step_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
674688
}
@@ -681,6 +695,7 @@ inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
681695
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
682696
float * dst_dd = static_cast<float *>(dst->data);
683697
dpct::queue_ptr main_stream = ctx.stream();
698+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
684699

685700
neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
686701
}
@@ -697,6 +712,7 @@ inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor
697712
float * dst_dd = static_cast<float *>(dst->data);
698713

699714
dpct::queue_ptr main_stream = ctx.stream();
715+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
700716

701717
leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), negative_slope, main_stream);
702718
}
@@ -709,6 +725,7 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
709725
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
710726
float * dst_dd = static_cast<float *>(dst->data);
711727
dpct::queue_ptr main_stream = ctx.stream();
728+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
712729

713730
sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
714731
}
@@ -727,6 +744,7 @@ inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor *
727744
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
728745
float * dst_dd = static_cast<float *>(dst->data);
729746
dpct::queue_ptr main_stream = ctx.stream();
747+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
730748

731749
upscale_f32_sycl(src0_dd, dst_dd, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], dst->src[0]->nb[3],
732750
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
@@ -743,6 +761,7 @@ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
743761
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
744762
float * dst_dd = static_cast<float *>(dst->data);
745763
dpct::queue_ptr main_stream = ctx.stream();
764+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
746765

747766
pad_f32_sycl(src0_dd, dst_dd,
748767
dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2],
@@ -760,6 +779,7 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx,
760779
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
761780

762781
const dpct::queue_ptr main_stream = ctx.stream();
782+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
763783
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
764784
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
765785
float * dst_dd = static_cast<float *>(dst->data);

ggml/src/ggml-sycl/getrows.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
8484
float * dst_dd = static_cast<float *>(dst->data);
8585

8686
dpct::queue_ptr stream = ctx.stream();
87+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
8788

8889
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
8990
k_get_rows<qk, qr, dq>(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12,
@@ -113,6 +114,7 @@ template <typename src0_t> static void get_rows_sycl_float(ggml_backend_sycl_con
113114
float * dst_dd = static_cast<float *>(dst->data);
114115

115116
dpct::queue_ptr stream = ctx.stream();
117+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
116118

117119
{
118120
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3081,7 +3081,7 @@ bool ggml_backend_is_sycl(ggml_backend_t backend) {
30813081
}
30823082

30833083
int ggml_backend_sycl_get_device_count() {
3084-
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
3084+
// GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count:\n");
30853085
return ggml_sycl_info().device_count;
30863086
}
30873087

ggml/src/ggml-sycl/gla.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor
8888
const int64_t H = dst->src[0]->ne[1];
8989

9090
dpct::queue_ptr stream = ctx.stream();
91+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
9192
GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
9293
GGML_ASSERT(C % H == 0);
9394
GGML_ASSERT(C / H == 64 || C / H == 128);

ggml/src/ggml-sycl/im2col.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ static void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * d
112112
const int64_t batch = dst->src[1]->ne[3];
113113
const size_t batch_offset = dst->src[1]->nb[3] / 4; // nb is byte offset, src is type float32
114114
dpct::queue_ptr main_stream = ctx.stream();
115+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
115116

116117
if (dst->type == GGML_TYPE_F16) {
117118
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);

ggml/src/ggml-sycl/norm.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ static void ggml_sycl_op_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst
326326
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
327327
float * dst_dd = static_cast<float *>(dst->data);
328328
dpct::queue_ptr main_stream = ctx.stream();
329+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
329330

330331
norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
331332
} catch (const sycl::exception & exc) {
@@ -348,6 +349,7 @@ static void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor*
348349
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
349350
float * dst_dd = static_cast<float *>(dst->data);
350351
dpct::queue_ptr main_stream = ctx.stream();
352+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
351353
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, dst->src[0]->ne[0] * dst->src[0]->ne[1] * dst->src[0]->ne[2], main_stream, ctx.device);
352354
} catch (const sycl::exception & exc) {
353355
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
@@ -368,6 +370,7 @@ static void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor *
368370
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
369371
float * dst_dd = static_cast<float *>(dst->data);
370372
dpct::queue_ptr main_stream = ctx.stream();
373+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
371374

372375
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
373376
} catch (const sycl::exception & exc) {

ggml/src/ggml-sycl/outprod.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
1717

1818
// Get SYCL queue
1919
dpct::queue_ptr stream = ctx.stream();
20+
// set device
21+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2022

2123
// Dimension checks
2224
GGML_ASSERT(ne01 == ne11); // Inner dimensions must match

ggml/src/ggml-sycl/pool2d.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * d
9393
const int parallel_elements = N * OC * OH * OW;
9494
const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
9595
dpct::queue_ptr main_stream = ctx.stream();
96+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
9697
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
9798
float * dst_dd = static_cast<float *>(dst->data);
9899
sycl::range<3> block_nums(1, 1, num_blocks);

ggml/src/ggml-sycl/rope.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ static void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst
236236
rope_corr_dims corr_dims;
237237
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
238238
dpct::queue_ptr main_stream = ctx.stream();
239+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
239240

240241
// compute
241242
if (is_neox) {

ggml/src/ggml-sycl/scale.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
2929
float * dst_dd = static_cast<float *>(dst->data);
3030

3131
dpct::queue_ptr main_stream = ctx.stream();
32+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
3233

3334
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
3435
/*

0 commit comments

Comments
 (0)