Skip to content

Commit 515fdbf

Browse files
authored
SYCL: revert "sycl: simplify bin_bcast_kernel (ggml-org#13383)" (ggml-org#13752)
Temporarily reverted due to failing fp16 DIV operation This reverts commit 02cdd2d. ggml-ci
1 parent f5cd27b commit 515fdbf

File tree

1 file changed

+232
-121
lines changed

1 file changed

+232
-121
lines changed

ggml/src/ggml-sycl/binbcast.cpp

Lines changed: 232 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,93 @@
11
#include "binbcast.hpp"
22

3-
#include <array>
43
#include <cstddef>
54
#include <cstdint>
65
#include <sycl/sycl.hpp>
76

8-
#include "dpct/helper.hpp"
97
#include "ggml.h"
108

11-
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
12-
static __dpct_inline__ void k_bin_bcast_contiguous(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1,
13-
dst_t * dst, std::size_t num_elements, const sycl::nd_item<1> & it) {
14-
auto element_id = it.get_global_id(0);
15-
auto global_range = it.get_global_range(0);
16-
for (; element_id < num_elements; element_id += global_range) {
17-
auto src0_float_val = sycl::vec(src0[element_id]).template convert<float, sycl::rounding_mode::rte>();
18-
auto src1_float_val = sycl::vec(src1[element_id]).template convert<float, sycl::rounding_mode::rte>();
19-
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
20-
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
21-
dst[element_id] = val_to_store;
9+
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
10+
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
11+
int ne0, int ne1, int ne2, int ne3,
12+
int ne10, int ne11, int ne12, int ne13,
13+
/*int s0, */ int s1, int s2, int s3,
14+
/*int s00,*/ int s01, int s02, int s03,
15+
/*int s10,*/ int s11, int s12, int s13,
16+
const sycl::nd_item<3> &item_ct1) {
17+
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
18+
item_ct1.get_local_id(2);
19+
const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
20+
item_ct1.get_local_id(1));
21+
const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
22+
item_ct1.get_local_id(0)) /
23+
ne3;
24+
const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
25+
item_ct1.get_local_id(0)) %
26+
ne3;
27+
28+
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
29+
return;
30+
}
31+
32+
const int i11 = i1 % ne11;
33+
const int i12 = i2 % ne12;
34+
const int i13 = i3 % ne13;
35+
36+
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
37+
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
38+
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
39+
40+
const src0_t * src0_row = src0 + i_src0;
41+
const src1_t * src1_row = src1 + i_src1;
42+
dst_t * dst_row = dst + i_dst;
43+
44+
for (int i0 = i0s; i0 < ne0;
45+
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
46+
const int i10 = i0 % ne10;
47+
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
2248
}
2349
}
2450

25-
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
26-
static __dpct_inline__ void k_bin_bcast(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst,
27-
int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13,
28-
int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10,
29-
int s11, int s12, int s13, std::size_t num_dst_elements,
30-
const sycl::nd_item<1> & item_ct1) {
31-
auto calculate_logical_index =
32-
[](const std::array<int, 4> & dims, std::size_t element_id) __attribute__((always_inline))->std::array<int, 4> {
33-
std::array<int, 4> logical_index;
34-
#pragma unroll(4)
35-
for (int i = 3; i >= 0; i--) {
36-
logical_index[i] = element_id % dims[i];
37-
element_id /= dims[i];
38-
}
39-
return logical_index;
40-
};
41-
42-
auto calculate_index = [](const std::array<int, 4> & dims, const std::array<int, 4> & strides,
43-
const std::array<int, 4> & indices) __attribute__((always_inline))
44-
->std::size_t {
45-
std::size_t index = 0;
46-
#pragma unroll(4)
47-
for (int i = 0; i < 4; i++) {
48-
auto index_i = indices[i];
49-
if (indices[i] >= dims[i]) {
50-
index_i = indices[i] % dims[i];
51-
}
52-
index += strides[i] * index_i;
53-
}
54-
return index;
55-
};
56-
57-
auto element_id = item_ct1.get_global_id(0);
58-
for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) {
59-
auto logical_index = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id);
60-
auto src_0_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index);
61-
auto src_1_index = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index);
62-
auto dst_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index);
63-
auto src0_float_val = sycl::vec(src0[src_0_index]).template convert<float, sycl::rounding_mode::rte>();
64-
auto src1_float_val = sycl::vec(src1[src_1_index]).template convert<float, sycl::rounding_mode::rte>();
65-
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
66-
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
67-
dst[dst_index] = val_to_store;
51+
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
52+
static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
53+
int ne0, int ne1, int ne2, int ne3,
54+
int ne10, int ne11, int ne12, int ne13,
55+
/*int s0, */ int s1, int s2, int s3,
56+
/*int s00,*/ int s01, int s02, int s03,
57+
/*int s10,*/ int s11, int s12, int s13,
58+
const sycl::nd_item<3> &item_ct1) {
59+
60+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
61+
item_ct1.get_local_id(2);
62+
63+
const int i3 = i/(ne2*ne1*ne0);
64+
const int i2 = (i/(ne1*ne0)) % ne2;
65+
const int i1 = (i/ne0) % ne1;
66+
const int i0 = i % ne0;
67+
68+
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
69+
return;
6870
}
71+
72+
const int i11 = i1 % ne11;
73+
const int i12 = i2 % ne12;
74+
const int i13 = i3 % ne13;
75+
76+
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
77+
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
78+
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
79+
80+
const src0_t * src0_row = src0 + i_src0;
81+
const src1_t * src1_row = src1 + i_src1;
82+
dst_t * dst_row = dst + i_dst;
83+
84+
const int i10 = i0 % ne10;
85+
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
6986
}
7087

71-
template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
88+
89+
template<float (*bin_op)(const float, const float)>
90+
struct bin_bcast_sycl {
7291
template <typename src0_t, typename src1_t, typename dst_t>
7392
void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
7493
const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
@@ -77,73 +96,165 @@ template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
7796
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
7897
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
7998
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
80-
auto check_bcast_required = [](const std::array<int64_t, 4> & src_dims,
81-
const std::array<int64_t, 4> & dst_dims) -> bool {
99+
int nr0 = ne10 / ne0;
100+
int nr1 = ne11/ne1;
101+
int nr2 = ne12/ne2;
102+
int nr3 = ne13/ne3;
103+
104+
int nr[4] = { nr0, nr1, nr2, nr3 };
105+
106+
// collapse dimensions until first broadcast dimension
107+
int64_t cne[] = {ne0, ne1, ne2, ne3};
108+
int64_t cne0[] = {ne00, ne01, ne02, ne03};
109+
int64_t cne1[] = {ne10, ne11, ne12, ne13};
110+
size_t cnb[] = {nb0, nb1, nb2, nb3};
111+
size_t cnb0[] = {nb00, nb01, nb02, nb03};
112+
size_t cnb1[] = {nb10, nb11, nb12, nb13};
113+
auto collapse = [](int64_t cne[]) {
114+
cne[0] *= cne[1];
115+
cne[1] = cne[2];
116+
cne[2] = cne[3];
117+
cne[3] = 1;
118+
};
119+
120+
auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
121+
cnb[1] *= cne[1];
122+
cnb[2] *= cne[2];
123+
cnb[3] *= cne[3];
124+
};
125+
126+
if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
82127
for (int i = 0; i < 4; i++) {
83-
if (dst_dims[i] > src_dims[i]) {
84-
return true;
128+
if (nr[i] != 1) {
129+
break;
130+
}
131+
if (i > 0) {
132+
collapse_nb(cnb, cne);
133+
collapse_nb(cnb0, cne0);
134+
collapse_nb(cnb1, cne1);
135+
collapse(cne);
136+
collapse(cne0);
137+
collapse(cne1);
85138
}
86139
}
87-
return false;
88-
};
89-
90-
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
91-
92-
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
93-
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
94-
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
95-
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
96-
97-
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
98-
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
99-
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
100-
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
101-
102-
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
103-
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
104-
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
105-
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
106-
107-
// dst strides in number of elements
108-
size_t s0 = nb0 / sizeof(dst_t);
109-
size_t s1 = nb1 / sizeof(dst_t);
110-
size_t s2 = nb2 / sizeof(dst_t);
111-
size_t s3 = nb3 / sizeof(dst_t);
112-
113-
// src1 strides in number of elements
114-
size_t s10 = nb10 / sizeof(src0_t);
115-
size_t s11 = nb11 / sizeof(src1_t);
116-
size_t s12 = nb12 / sizeof(src1_t);
117-
size_t s13 = nb13 / sizeof(src1_t);
118-
119-
// src0 strides in number of elements
120-
size_t s00 = nb00 / sizeof(src0_t);
121-
size_t s01 = nb01 / sizeof(src0_t);
122-
size_t s02 = nb02 / sizeof(src0_t);
123-
size_t s03 = nb03 / sizeof(src0_t);
124-
125-
std::size_t num_dst_elements = static_cast<std::size_t>(ne0) * static_cast<std::size_t>(ne1) *
126-
static_cast<std::size_t>(ne2) * static_cast<std::size_t>(ne3);
127-
std::size_t local_range = 256;
128-
std::size_t global_range = ceil_div(num_dst_elements, local_range) * local_range;
129-
130-
bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) ||
131-
check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 });
132-
bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous;
133-
134-
if (! needs_broadcasting && all_contiguous) {
135-
stream->submit([&](sycl::handler & cgh) {
136-
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
137-
k_bin_bcast_contiguous<bin_op>(src0_dd, src1_dd, dst_dd, num_dst_elements, it);
138-
});
139-
});
140-
} else {
141-
stream->submit([&](sycl::handler & cgh) {
142-
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
143-
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1,
144-
s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it);
145-
});
146-
});
140+
}
141+
{
142+
int64_t ne0 = cne[0];
143+
int64_t ne1 = cne[1];
144+
int64_t ne2 = cne[2];
145+
int64_t ne3 = cne[3];
146+
147+
int64_t ne10 = cne1[0];
148+
int64_t ne11 = cne1[1];
149+
int64_t ne12 = cne1[2];
150+
int64_t ne13 = cne1[3];
151+
152+
size_t nb0 = cnb[0];
153+
size_t nb1 = cnb[1];
154+
size_t nb2 = cnb[2];
155+
size_t nb3 = cnb[3];
156+
157+
size_t nb00 = cnb0[0];
158+
size_t nb01 = cnb0[1];
159+
size_t nb02 = cnb0[2];
160+
size_t nb03 = cnb0[3];
161+
162+
size_t nb10 = cnb1[0];
163+
size_t nb11 = cnb1[1];
164+
size_t nb12 = cnb1[2];
165+
size_t nb13 = cnb1[3];
166+
167+
size_t s0 = nb0 / sizeof(dst_t);
168+
size_t s1 = nb1 / sizeof(dst_t);
169+
size_t s2 = nb2 / sizeof(dst_t);
170+
size_t s3 = nb3 / sizeof(dst_t);
171+
172+
size_t s10 = nb10 / sizeof(src1_t);
173+
size_t s11 = nb11 / sizeof(src1_t);
174+
size_t s12 = nb12 / sizeof(src1_t);
175+
size_t s13 = nb13 / sizeof(src1_t);
176+
177+
size_t s00 = nb00 / sizeof(src0_t);
178+
size_t s01 = nb01 / sizeof(src0_t);
179+
size_t s02 = nb02 / sizeof(src0_t);
180+
size_t s03 = nb03 / sizeof(src0_t);
181+
182+
GGML_UNUSED(s00);
183+
184+
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
185+
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
186+
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
187+
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
188+
189+
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
190+
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
191+
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
192+
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
193+
194+
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
195+
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
196+
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
197+
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
198+
199+
GGML_ASSERT(s0 == 1);
200+
GGML_ASSERT(s10 == 1);
201+
202+
const int block_size = 128;
203+
204+
int64_t hne0 = std::max(ne0/2LL, 1LL);
205+
206+
sycl::range<3> block_dims(1, 1, 1);
207+
block_dims[2] = std::min<unsigned int>(hne0, block_size);
208+
block_dims[1] = std::min<unsigned int>(
209+
ne1, block_size / (unsigned int)block_dims[2]);
210+
block_dims[0] = std::min(
211+
std::min<unsigned int>(
212+
ne2 * ne3, block_size / (unsigned int)block_dims[2] /
213+
(unsigned int)block_dims[1]),
214+
64U);
215+
216+
sycl::range<3> block_nums(
217+
(ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
218+
(ne1 + block_dims[1] - 1) / block_dims[1],
219+
(hne0 + block_dims[2] - 1) / block_dims[2]);
220+
221+
if (block_nums[0] > 65535) {
222+
// this is the maximum number of blocks in z direction, fallback to 1D grid kernel
223+
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
224+
{
225+
dpct::has_capability_or_fail(stream->get_device(),
226+
{sycl::aspect::fp16});
227+
228+
stream->parallel_for(
229+
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
230+
sycl::range<3>(1, 1, block_size),
231+
sycl::range<3>(1, 1, block_size)),
232+
[=](sycl::nd_item<3> item_ct1) {
233+
k_bin_bcast_unravel<bin_op>(
234+
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
235+
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
236+
s03, s11, s12, s13, item_ct1);
237+
});
238+
}
239+
} else {
240+
/*
241+
DPCT1049:16: The work-group size passed to the SYCL kernel may
242+
exceed the limit. To get the device limit, query
243+
info::device::max_work_group_size. Adjust the work-group size if
244+
needed.
245+
*/
246+
dpct::has_capability_or_fail(stream->get_device(),
247+
{sycl::aspect::fp16});
248+
249+
stream->parallel_for(
250+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
251+
[=](sycl::nd_item<3> item_ct1) {
252+
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
253+
ne2, ne3, ne10, ne11, ne12, ne13,
254+
s1, s2, s3, s01, s02, s03, s11, s12, s13,
255+
item_ct1);
256+
});
257+
}
147258
}
148259
}
149260
};

0 commit comments

Comments
 (0)