Skip to content

Commit 8308f98

Browse files
authored
sycl: add usage of enqueue_functions extension (#14244)
* Add header and namespace to use enqueue_functions extension * Convert submit and parallel_for to use new extension in convert.cpp * Convert submit and parallel_for to use extension in ggml-sycl.cpp * Convert submit and parallel_for to use extension in gla.cpp * Convert submit and parallel_for in mmq.cpp * Convert submit and parallel_for in mmvq.cpp * Convert submit and parallel_for in remaining files * Convert all simple parallel_for to nd_launch from enqueue_functions extension * Wrapping extension in general function Create a general function that enables the enqueue_functions extension if it is enabled in the compiler, otherwise call the general SYCL function to launch kernels. --------- Signed-off-by: nscipione <[email protected]>
1 parent 6369be0 commit 8308f98

19 files changed

+750
-986
lines changed

ggml/src/ggml-sycl/binbcast.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ struct bin_bcast_sycl {
225225
dpct::has_capability_or_fail(stream->get_device(),
226226
{sycl::aspect::fp16});
227227

228-
stream->parallel_for(
229-
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
230-
sycl::range<3>(1, 1, block_size),
228+
sycl_parallel_for(
229+
stream,
230+
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * sycl::range<3>(1, 1, block_size),
231231
sycl::range<3>(1, 1, block_size)),
232232
[=](sycl::nd_item<3> item_ct1) {
233233
k_bin_bcast_unravel<bin_op>(
@@ -246,9 +246,8 @@ struct bin_bcast_sycl {
246246
dpct::has_capability_or_fail(stream->get_device(),
247247
{sycl::aspect::fp16});
248248

249-
stream->parallel_for(
250-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
251-
[=](sycl::nd_item<3> item_ct1) {
249+
sycl_parallel_for(
250+
stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
252251
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
253252
ne2, ne3, ne10, ne11, ne12, ne13,
254253
s1, s2, s3, s01, s02, s03, s11, s12, s13,

ggml/src/ggml-sycl/concat.cpp

Lines changed: 28 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -89,33 +89,24 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
8989
sycl::range<3> gridDim(ne2, ne1, num_blocks);
9090
switch (dim) {
9191
case 0:
92-
stream->parallel_for(
93-
sycl::nd_range<3>(gridDim *
94-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
95-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
96-
[=](sycl::nd_item<3> item_ct1) {
97-
concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
98-
});
99-
break;
92+
sycl_parallel_for(stream,
93+
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
94+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
95+
[=](sycl::nd_item<3> item_ct1) { concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1); });
96+
break;
10097
case 1:
101-
stream->parallel_for(
102-
sycl::nd_range<3>(gridDim *
103-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
104-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
105-
[=](sycl::nd_item<3> item_ct1) {
106-
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
107-
});
108-
break;
98+
sycl_parallel_for(stream,
99+
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
100+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
101+
[=](sycl::nd_item<3> item_ct1) { concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); });
102+
break;
109103
// dim >=2 will be dispatched to the default path
110104
default:
111-
stream->parallel_for(
112-
sycl::nd_range<3>(gridDim *
113-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
114-
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
115-
[=](sycl::nd_item<3> item_ct1) {
116-
concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
117-
});
118-
break;
105+
sycl_parallel_for(stream,
106+
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
107+
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
108+
[=](sycl::nd_item<3> item_ct1) { concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1); });
109+
break;
119110
}
120111
}
121112

@@ -129,33 +120,29 @@ static void concat_f32_sycl_non_cont(
129120
int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
130121
uint64_t nb3, int32_t dim) {
131122
sycl::range<3> gridDim(ne3, ne2, ne1);
132-
stream->parallel_for(
133-
sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)),
134-
[=](sycl::nd_item<3> item_ct1) {
135-
int64_t i3 = item_ct1.get_group(0);
136-
int64_t i2 = item_ct1.get_group(1);
137-
int64_t i1 = item_ct1.get_group(2);
123+
sycl_parallel_for(stream, sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
124+
int64_t i3 = item_ct1.get_group(0);
125+
int64_t i2 = item_ct1.get_group(1);
126+
int64_t i1 = item_ct1.get_group(2);
138127

139-
int64_t o[4] = {0, 0, 0, 0};
140-
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
128+
int64_t o[4] = { 0, 0, 0, 0 };
129+
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
141130

142-
const float *x;
131+
const float * x;
143132

144-
for (int i0 = item_ct1.get_local_id(2); i0 < ne0;
145-
i0 += item_ct1.get_local_range(2)) {
133+
for (int i0 = item_ct1.get_local_id(2); i0 < ne0; i0 += item_ct1.get_local_range(2)) {
146134
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
147-
x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 +
148-
(i0)*nb00);
135+
x = (const float *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
149136
} else {
150-
x = (const float *)(src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 +
151-
(i1 - o[1]) * nb11 + (i0 - o[0]) * nb10);
137+
x = (const float *) (src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + (i1 - o[1]) * nb11 +
138+
(i0 - o[0]) * nb10);
152139
}
153140

154141
float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
155142

156143
*y = *x;
157-
}
158-
});
144+
}
145+
});
159146
}
160147

161148
void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {

ggml/src/ggml-sycl/conv.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,10 @@ static void conv_transpose_1d_f32_f32_sycl(
5959
const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
6060
const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
6161
const sycl::range<3> block_nums(1, 1, num_blocks);
62-
stream->parallel_for(
63-
sycl::nd_range<3>(
64-
block_nums * block_dims, block_dims),
65-
[=](sycl::nd_item<3> item_ct1) {
66-
conv_transpose_1d_kernel(
67-
s0, output_size,
68-
src0_ne0, src0_ne1, src0_ne2,
69-
src1_ne0, dst_ne0,
70-
src0, src1, dst, item_ct1);
71-
});
62+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
63+
conv_transpose_1d_kernel(s0, output_size, src0_ne0, src0_ne1, src0_ne2, src1_ne0, dst_ne0, src0, src1, dst,
64+
item_ct1);
65+
});
7266
}
7367

7468
void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {

0 commit comments

Comments
 (0)