Skip to content

Commit 4731850

Browse files
committed
SYCL: refactor move to a separate file
1 parent b43d89e commit 4731850

File tree

7 files changed

+380
-370
lines changed

7 files changed

+380
-370
lines changed

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef GGML_SYCL_BACKEND_HPP
1414
#define GGML_SYCL_BACKEND_HPP
1515

16+
#include "binbcast.hpp"
1617
#include "concat.hpp"
1718
#include "common.hpp"
1819
#include "conv.hpp"

ggml/src/ggml-sycl/binbcast.cpp

Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
#include "binbcast.hpp"
2+
#include <sycl/sycl.hpp>
3+
4+
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
5+
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
6+
int ne0, int ne1, int ne2, int ne3,
7+
int ne10, int ne11, int ne12, int ne13,
8+
/*int s0, */ int s1, int s2, int s3,
9+
/*int s00,*/ int s01, int s02, int s03,
10+
/*int s10,*/ int s11, int s12, int s13,
11+
const sycl::nd_item<3> &item_ct1) {
12+
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
13+
item_ct1.get_local_id(2);
14+
const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
15+
item_ct1.get_local_id(1));
16+
const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
17+
item_ct1.get_local_id(0)) /
18+
ne3;
19+
const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
20+
item_ct1.get_local_id(0)) %
21+
ne3;
22+
23+
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
24+
return;
25+
}
26+
27+
const int i11 = i1 % ne11;
28+
const int i12 = i2 % ne12;
29+
const int i13 = i3 % ne13;
30+
31+
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
32+
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
33+
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
34+
35+
const src0_t * src0_row = src0 + i_src0;
36+
const src1_t * src1_row = src1 + i_src1;
37+
dst_t * dst_row = dst + i_dst;
38+
39+
for (int i0 = i0s; i0 < ne0;
40+
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
41+
const int i10 = i0 % ne10;
42+
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
43+
}
44+
}
45+
46+
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
47+
static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
48+
int ne0, int ne1, int ne2, int ne3,
49+
int ne10, int ne11, int ne12, int ne13,
50+
/*int s0, */ int s1, int s2, int s3,
51+
/*int s00,*/ int s01, int s02, int s03,
52+
/*int s10,*/ int s11, int s12, int s13,
53+
const sycl::nd_item<3> &item_ct1) {
54+
55+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
56+
item_ct1.get_local_id(2);
57+
58+
const int i3 = i/(ne2*ne1*ne0);
59+
const int i2 = (i/(ne1*ne0)) % ne2;
60+
const int i1 = (i/ne0) % ne1;
61+
const int i0 = i % ne0;
62+
63+
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
64+
return;
65+
}
66+
67+
const int i11 = i1 % ne11;
68+
const int i12 = i2 % ne12;
69+
const int i13 = i3 % ne13;
70+
71+
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
72+
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
73+
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
74+
75+
const src0_t * src0_row = src0 + i_src0;
76+
const src1_t * src1_row = src1 + i_src1;
77+
dst_t * dst_row = dst + i_dst;
78+
79+
const int i10 = i0 % ne10;
80+
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
81+
}
82+
83+
84+
template<float (*bin_op)(const float, const float)>
85+
struct bin_bcast_sycl {
86+
template <typename src0_t, typename src1_t, typename dst_t>
87+
void operator()(ggml_backend_sycl_context & ctx,
88+
const struct ggml_tensor *src0,
89+
const struct ggml_tensor *src1, struct ggml_tensor *dst,
90+
const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
91+
queue_ptr stream) {
92+
93+
GGML_TENSOR_BINARY_OP_LOCALS
94+
95+
int nr0 = ne10/ne0;
96+
int nr1 = ne11/ne1;
97+
int nr2 = ne12/ne2;
98+
int nr3 = ne13/ne3;
99+
100+
int nr[4] = { nr0, nr1, nr2, nr3 };
101+
102+
// collapse dimensions until first broadcast dimension
103+
int64_t cne[] = {ne0, ne1, ne2, ne3};
104+
int64_t cne0[] = {ne00, ne01, ne02, ne03};
105+
int64_t cne1[] = {ne10, ne11, ne12, ne13};
106+
size_t cnb[] = {nb0, nb1, nb2, nb3};
107+
size_t cnb0[] = {nb00, nb01, nb02, nb03};
108+
size_t cnb1[] = {nb10, nb11, nb12, nb13};
109+
auto collapse = [](int64_t cne[]) {
110+
cne[0] *= cne[1];
111+
cne[1] = cne[2];
112+
cne[2] = cne[3];
113+
cne[3] = 1;
114+
};
115+
116+
auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
117+
cnb[1] *= cne[1];
118+
cnb[2] *= cne[2];
119+
cnb[3] *= cne[3];
120+
};
121+
122+
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
123+
for (int i = 0; i < 4; i++) {
124+
if (nr[i] != 1) {
125+
break;
126+
}
127+
if (i > 0) {
128+
collapse_nb(cnb, cne);
129+
collapse_nb(cnb0, cne0);
130+
collapse_nb(cnb1, cne1);
131+
collapse(cne);
132+
collapse(cne0);
133+
collapse(cne1);
134+
}
135+
}
136+
}
137+
{
138+
int64_t ne0 = cne[0];
139+
int64_t ne1 = cne[1];
140+
int64_t ne2 = cne[2];
141+
int64_t ne3 = cne[3];
142+
143+
int64_t ne10 = cne1[0];
144+
int64_t ne11 = cne1[1];
145+
int64_t ne12 = cne1[2];
146+
int64_t ne13 = cne1[3];
147+
148+
size_t nb0 = cnb[0];
149+
size_t nb1 = cnb[1];
150+
size_t nb2 = cnb[2];
151+
size_t nb3 = cnb[3];
152+
153+
size_t nb00 = cnb0[0];
154+
size_t nb01 = cnb0[1];
155+
size_t nb02 = cnb0[2];
156+
size_t nb03 = cnb0[3];
157+
158+
size_t nb10 = cnb1[0];
159+
size_t nb11 = cnb1[1];
160+
size_t nb12 = cnb1[2];
161+
size_t nb13 = cnb1[3];
162+
163+
size_t s0 = nb0 / sizeof(dst_t);
164+
size_t s1 = nb1 / sizeof(dst_t);
165+
size_t s2 = nb2 / sizeof(dst_t);
166+
size_t s3 = nb3 / sizeof(dst_t);
167+
168+
size_t s10 = nb10 / sizeof(src1_t);
169+
size_t s11 = nb11 / sizeof(src1_t);
170+
size_t s12 = nb12 / sizeof(src1_t);
171+
size_t s13 = nb13 / sizeof(src1_t);
172+
173+
size_t s00 = nb00 / sizeof(src0_t);
174+
size_t s01 = nb01 / sizeof(src0_t);
175+
size_t s02 = nb02 / sizeof(src0_t);
176+
size_t s03 = nb03 / sizeof(src0_t);
177+
178+
GGML_UNUSED(s00);
179+
180+
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
181+
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
182+
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
183+
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
184+
185+
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
186+
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
187+
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
188+
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
189+
190+
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
191+
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
192+
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
193+
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
194+
195+
GGML_ASSERT(s0 == 1);
196+
GGML_ASSERT(s10 == 1);
197+
198+
const int block_size = 128;
199+
200+
int64_t hne0 = std::max(ne0/2LL, 1LL);
201+
202+
sycl::range<3> block_dims(1, 1, 1);
203+
block_dims[2] = std::min<unsigned int>(hne0, block_size);
204+
block_dims[1] = std::min<unsigned int>(
205+
ne1, block_size / (unsigned int)block_dims[2]);
206+
block_dims[0] = std::min(
207+
std::min<unsigned int>(
208+
ne2 * ne3, block_size / (unsigned int)block_dims[2] /
209+
(unsigned int)block_dims[1]),
210+
64U);
211+
212+
sycl::range<3> block_nums(
213+
(ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
214+
(ne1 + block_dims[1] - 1) / block_dims[1],
215+
(hne0 + block_dims[2] - 1) / block_dims[2]);
216+
217+
if (block_nums[0] > 65535) {
218+
// this is the maximum number of blocks in z direction, fallback to 1D grid kernel
219+
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
220+
{
221+
dpct::has_capability_or_fail(stream->get_device(),
222+
{sycl::aspect::fp16});
223+
224+
stream->parallel_for(
225+
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
226+
sycl::range<3>(1, 1, block_size),
227+
sycl::range<3>(1, 1, block_size)),
228+
[=](sycl::nd_item<3> item_ct1) {
229+
k_bin_bcast_unravel<bin_op>(
230+
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
231+
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
232+
s03, s11, s12, s13, item_ct1);
233+
});
234+
}
235+
} else {
236+
/*
237+
DPCT1049:16: The work-group size passed to the SYCL kernel may
238+
exceed the limit. To get the device limit, query
239+
info::device::max_work_group_size. Adjust the work-group size if
240+
needed.
241+
*/
242+
dpct::has_capability_or_fail(stream->get_device(),
243+
{sycl::aspect::fp16});
244+
245+
stream->parallel_for(
246+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
247+
[=](sycl::nd_item<3> item_ct1) {
248+
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
249+
ne2, ne3, ne10, ne11, ne12, ne13,
250+
s1, s2, s3, s01, s02, s03, s11, s12, s13,
251+
item_ct1);
252+
});
253+
}
254+
}
255+
GGML_UNUSED(ctx);
256+
}
257+
};
258+
259+
template <class op>
260+
inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
261+
const ggml_tensor *src1, ggml_tensor *dst) {
262+
dpct::queue_ptr main_stream = ctx.stream();
263+
264+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
265+
op()(ctx, src0, src1, dst, (const float *)src0->data, (const float *)src1->data, (float *)dst->data, main_stream);
266+
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
267+
op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data,
268+
(sycl::half *)dst->data, main_stream);
269+
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
270+
op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, (float *)dst->data,
271+
main_stream);
272+
} else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
273+
op()(ctx, src0, src1, dst, (const int32_t *)src0->data, (const int32_t *)src1->data, (int32_t *)dst->data,
274+
main_stream);
275+
} else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
276+
op()(ctx, src0, src1, dst, (const int16_t *)src0->data, (const int16_t *)src1->data, (int16_t *)dst->data,
277+
main_stream);
278+
} else {
279+
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
280+
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
281+
GGML_ABORT("fatal error");
282+
}
283+
}
284+
285+
inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
286+
287+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, dst->src[0], dst->src[1], dst);
288+
}
289+
290+
inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
291+
292+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
293+
}
294+
295+
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
296+
297+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
298+
}
299+
300+
inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
301+
302+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, dst->src[0], dst->src[1], dst);
303+
}
304+
305+
inline void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
306+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst);
307+
}
308+
309+
310+
void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
311+
GGML_SYCL_DEBUG("call %s\n", __func__);
312+
ggml_sycl_op_add(ctx, dst);
313+
GGML_SYCL_DEBUG("call %s done\n", __func__);
314+
}
315+
316+
void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
317+
GGML_SYCL_DEBUG("call %s\n", __func__);
318+
ggml_sycl_op_sub(ctx, dst);
319+
GGML_SYCL_DEBUG("call %s done\n", __func__);
320+
}
321+
322+
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
323+
GGML_SYCL_DEBUG("call %s\n", __func__);
324+
ggml_sycl_op_mul(ctx, dst);
325+
GGML_SYCL_DEBUG("call %s done\n", __func__);
326+
}
327+
328+
void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
329+
GGML_SYCL_DEBUG("call %s\n", __func__);
330+
ggml_sycl_op_div(ctx, dst);
331+
GGML_SYCL_DEBUG("call %s done\n", __func__);
332+
}
333+
334+
void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
335+
GGML_SYCL_DEBUG("call %s\n", __func__);
336+
ggml_sycl_op_repeat(ctx, dst);
337+
GGML_SYCL_DEBUG("call %s done\n", __func__);
338+
}
339+

ggml/src/ggml-sycl/binbcast.hpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#ifndef GGML_SYCL_BINBCAST_HPP
2+
#define GGML_SYCL_BINBCAST_HPP
3+
#include "common.hpp"
4+
5+
6+
static __dpct_inline__ float op_repeat(const float a, const float b) {
7+
return b;
8+
GGML_UNUSED(a);
9+
}
10+
11+
static __dpct_inline__ float op_add(const float a, const float b) {
12+
return a + b;
13+
}
14+
15+
static __dpct_inline__ float op_sub(const float a, const float b) {
16+
return a - b;
17+
}
18+
19+
static __dpct_inline__ float op_mul(const float a, const float b) {
20+
return a * b;
21+
}
22+
23+
static __dpct_inline__ float op_div(const float a, const float b) {
24+
return a / b;
25+
}
26+
27+
void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
28+
29+
void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
30+
31+
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
32+
33+
void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
34+
35+
void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
36+
37+
38+
#endif //GGML_SYCL_BINBCAST_HPP
39+

0 commit comments

Comments
 (0)