Skip to content

Commit d9d2e1d

Browse files
committed
getrows: move to a separate file
1 parent 6425a70 commit d9d2e1d

File tree

4 files changed

+174
-191
lines changed

4 files changed

+174
-191
lines changed

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "argmax.hpp"
3434
#include "argsort.hpp"
3535
#include "cpy.hpp"
36+
#include "getrows.hpp"
3637
#include "gla.hpp"
3738

3839
#endif // GGML_SYCL_BACKEND_HPP

ggml/src/ggml-sycl/getrows.cpp

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
#include "getrows.hpp"
2+
#include "dequantize.hpp"
3+
4+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
5+
static void k_get_rows(const void * src0, const int32_t * src1, dst_t * dst,
6+
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
7+
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
8+
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
9+
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, size_t s10, size_t s11, size_t s12,
10+
const sycl::nd_item<3> & item_ct1 /*, size_t s13*/) {
11+
const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2)) * 2;
12+
const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1);
13+
const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + item_ct1.get_local_id(0)) / ne12;
14+
const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + item_ct1.get_local_id(0)) % ne12;
15+
16+
if (i00 >= ne00) {
17+
return;
18+
}
19+
20+
const int i01 = src1[i10 * s10 + i11 * s11 + i12 * s12];
21+
22+
dst_t * dst_row = dst + i10 * s1 + i11 * s2 + i12 * s3;
23+
const void * src0_row = (const char *) src0 + i01 * nb01 + i11 * nb02 + i12 * nb03;
24+
25+
const int ib = i00 / qk; // block index
26+
const int iqs = (i00 % qk) / qr; // quant index
27+
const int iybs = i00 - i00 % qk; // dst block start index
28+
const int y_offset = qr == 1 ? 1 : qk / 2;
29+
30+
// dequantize
31+
dfloat2 v;
32+
dequantize_kernel(src0_row, ib, iqs, v);
33+
34+
dst_row[iybs + iqs + 0] = v.x();
35+
dst_row[iybs + iqs + y_offset] = v.y();
36+
}
37+
38+
template <typename src0_t, typename dst_t>
39+
static void k_get_rows_float(const src0_t * src0, const int32_t * src1, dst_t * dst,
40+
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
41+
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
42+
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
43+
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, size_t s10, size_t s11, size_t s12,
44+
const sycl::nd_item<3> & item_ct1 /*, size_t s13*/) {
45+
const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
46+
const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1);
47+
const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + item_ct1.get_local_id(0)) / ne12;
48+
const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) + item_ct1.get_local_id(0)) % ne12;
49+
50+
if (i00 >= ne00) {
51+
return;
52+
}
53+
54+
const int i01 = src1[i10 * s10 + i11 * s11 + i12 * s12];
55+
56+
dst_t * dst_row = dst + i10 * s1 + i11 * s2 + i12 * s3;
57+
const src0_t * src0_row = (const src0_t *) ((const char *) src0 + i01 * nb01 + i11 * nb02 + i12 * nb03);
58+
59+
dst_row[i00] = src0_row[i00];
60+
}
61+
62+
template <int qk, int qr, dequantize_kernel_t dq>
63+
static void get_rows_sycl(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
64+
GGML_SYCL_TENSOR_BINARY_OP_LOCALS
65+
66+
const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
67+
const int block_num_x = (ne00 + 2 * SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2 * SYCL_GET_ROWS_BLOCK_SIZE);
68+
const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
69+
70+
// strides in elements
71+
//const size_t s0 = nb0 / ggml_element_size(dst);
72+
const size_t s1 = nb1 / ggml_element_size(dst);
73+
const size_t s2 = nb2 / ggml_element_size(dst);
74+
const size_t s3 = nb3 / ggml_element_size(dst);
75+
76+
const size_t s10 = nb10 / ggml_element_size(dst->src[1]);
77+
const size_t s11 = nb11 / ggml_element_size(dst->src[1]);
78+
const size_t s12 = nb12 / ggml_element_size(dst->src[1]);
79+
//const size_t s13 = nb13 / ggml_element_size(dst->src[1]);
80+
81+
GGML_ASSERT(ne00 % 2 == 0);
82+
const void * src0_dd = dst->src[0]->data;
83+
const int32_t * src1_dd = static_cast<const int32_t *>(dst->src[1]->data);
84+
float * dst_dd = static_cast<float *>(dst->data);
85+
86+
dpct::queue_ptr stream = ctx.stream();
87+
88+
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
89+
k_get_rows<qk, qr, dq>(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12,
90+
item_ct1);
91+
});
92+
}
93+
94+
template <typename src0_t> static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
95+
GGML_SYCL_TENSOR_BINARY_OP_LOCALS
96+
97+
const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
98+
const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
99+
const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
100+
101+
// strides in elements
102+
//const size_t s0 = nb0 / ggml_element_size(dst);
103+
const size_t s1 = nb1 / ggml_element_size(dst);
104+
const size_t s2 = nb2 / ggml_element_size(dst);
105+
const size_t s3 = nb3 / ggml_element_size(dst);
106+
107+
const size_t s10 = nb10 / ggml_element_size(dst->src[1]);
108+
const size_t s11 = nb11 / ggml_element_size(dst->src[1]);
109+
const size_t s12 = nb12 / ggml_element_size(dst->src[1]);
110+
//const size_t s13 = nb13 / ggml_element_size(dst->src[1]);
111+
const src0_t * src0_dd = static_cast<const src0_t *>(dst->src[0]->data);
112+
const int32_t * src1_dd = static_cast<const int32_t *>(dst->src[1]->data);
113+
float * dst_dd = static_cast<float *>(dst->data);
114+
115+
dpct::queue_ptr stream = ctx.stream();
116+
117+
{
118+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
119+
120+
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
121+
k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12,
122+
item_ct1);
123+
});
124+
}
125+
}
126+
127+
void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
128+
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32);
129+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
130+
131+
GGML_ASSERT(dst->src[0]->nb[0] == ggml_type_size(dst->src[0]->type));
132+
GGML_ASSERT(dst->src[1]->nb[0] == ggml_type_size(dst->src[1]->type));
133+
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
134+
GGML_ASSERT(strcmp(dst->src[1]->buffer->buft->iface.get_name(dst->src[1]->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
135+
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
136+
137+
switch (dst->src[0]->type) {
138+
case GGML_TYPE_F16:
139+
get_rows_sycl_float<sycl::half>(ctx, dst);
140+
break;
141+
case GGML_TYPE_F32:
142+
get_rows_sycl_float<float>(ctx, dst);
143+
break;
144+
case GGML_TYPE_Q4_0:
145+
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
146+
break;
147+
case GGML_TYPE_Q4_1:
148+
get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst);
149+
break;
150+
case GGML_TYPE_Q5_0:
151+
get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, dst);
152+
break;
153+
case GGML_TYPE_Q5_1:
154+
get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, dst);
155+
break;
156+
case GGML_TYPE_Q8_0:
157+
get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
158+
break;
159+
default:
160+
// TODO: k-quants
161+
GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(dst->src[0]->type));
162+
GGML_ABORT("fatal error");
163+
break;
164+
}
165+
}

ggml/src/ggml-sycl/getrows.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef GGML_SYCL_GETROWS_HPP
2+
#define GGML_SYCL_GETROWS_HPP
3+
4+
#include "common.hpp"
5+
6+
void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7+
8+
#endif // GGML_SYCL_GETROWS_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 0 additions & 191 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,83 +1336,6 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
13361336
reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
13371337
}
13381338

1339-
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
1340-
static void k_get_rows(
1341-
const void * src0, const int32_t * src1, dst_t * dst,
1342-
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
1343-
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
1344-
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
1345-
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
1346-
size_t s10, size_t s11, size_t s12,
1347-
const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
1348-
1349-
const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
1350-
item_ct1.get_local_id(2)) *
1351-
2;
1352-
const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1353-
item_ct1.get_local_id(1);
1354-
const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
1355-
item_ct1.get_local_id(0)) /
1356-
ne12;
1357-
const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
1358-
item_ct1.get_local_id(0)) %
1359-
ne12;
1360-
1361-
if (i00 >= ne00) {
1362-
return;
1363-
}
1364-
1365-
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
1366-
1367-
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
1368-
const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
1369-
1370-
const int ib = i00/qk; // block index
1371-
const int iqs = (i00%qk)/qr; // quant index
1372-
const int iybs = i00 - i00%qk; // dst block start index
1373-
const int y_offset = qr == 1 ? 1 : qk/2;
1374-
1375-
// dequantize
1376-
dfloat2 v;
1377-
dequantize_kernel(src0_row, ib, iqs, v);
1378-
1379-
dst_row[iybs + iqs + 0] = v.x();
1380-
dst_row[iybs + iqs + y_offset] = v.y();
1381-
}
1382-
1383-
template<typename src0_t, typename dst_t>
1384-
static void k_get_rows_float(
1385-
const src0_t * src0, const int32_t * src1, dst_t * dst,
1386-
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
1387-
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
1388-
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
1389-
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
1390-
size_t s10, size_t s11, size_t s12,
1391-
const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
1392-
1393-
const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
1394-
item_ct1.get_local_id(2);
1395-
const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1396-
item_ct1.get_local_id(1);
1397-
const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
1398-
item_ct1.get_local_id(0)) /
1399-
ne12;
1400-
const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
1401-
item_ct1.get_local_id(0)) %
1402-
ne12;
1403-
1404-
if (i00 >= ne00) {
1405-
return;
1406-
}
1407-
1408-
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
1409-
1410-
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
1411-
const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
1412-
1413-
dst_row[i00] = src0_row[i00];
1414-
}
1415-
14161339
static void mul_mat_p021_f16_f32(
14171340
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
14181341
const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1644,80 +1567,6 @@ static void pool2d_nchw_kernel(
16441567
o_ptr[cur_oh * ow + cur_ow] = res;
16451568
}
16461569

1647-
template <int qk, int qr, dequantize_kernel_t dq>
1648-
static void get_rows_sycl(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
1649-
1650-
GGML_SYCL_TENSOR_BINARY_OP_LOCALS
1651-
1652-
const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
1653-
const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
1654-
const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
1655-
1656-
// strides in elements
1657-
//const size_t s0 = nb0 / ggml_element_size(dst);
1658-
const size_t s1 = nb1 / ggml_element_size(dst);
1659-
const size_t s2 = nb2 / ggml_element_size(dst);
1660-
const size_t s3 = nb3 / ggml_element_size(dst);
1661-
1662-
const size_t s10 = nb10 / ggml_element_size(dst->src[1]);
1663-
const size_t s11 = nb11 / ggml_element_size(dst->src[1]);
1664-
const size_t s12 = nb12 / ggml_element_size(dst->src[1]);
1665-
//const size_t s13 = nb13 / ggml_element_size(dst->src[1]);
1666-
1667-
GGML_ASSERT(ne00 % 2 == 0);
1668-
const void * src0_dd = dst->src[0]->data;
1669-
const int32_t * src1_dd = static_cast<const int32_t *>(dst->src[1]->data);
1670-
float * dst_dd = static_cast<float *>(dst->data);
1671-
1672-
dpct::queue_ptr stream = ctx.stream();
1673-
1674-
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
1675-
[=](sycl::nd_item<3> item_ct1) {
1676-
k_get_rows<qk, qr, dq>(
1677-
src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
1678-
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
1679-
});
1680-
1681-
}
1682-
1683-
template <typename src0_t>
1684-
static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1685-
1686-
GGML_SYCL_TENSOR_BINARY_OP_LOCALS
1687-
1688-
const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
1689-
const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
1690-
const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
1691-
1692-
// strides in elements
1693-
//const size_t s0 = nb0 / ggml_element_size(dst);
1694-
const size_t s1 = nb1 / ggml_element_size(dst);
1695-
const size_t s2 = nb2 / ggml_element_size(dst);
1696-
const size_t s3 = nb3 / ggml_element_size(dst);
1697-
1698-
const size_t s10 = nb10 / ggml_element_size(dst->src[1]);
1699-
const size_t s11 = nb11 / ggml_element_size(dst->src[1]);
1700-
const size_t s12 = nb12 / ggml_element_size(dst->src[1]);
1701-
//const size_t s13 = nb13 / ggml_element_size(dst->src[1]);
1702-
const src0_t * src0_dd = static_cast<const src0_t *>(dst->src[0]->data);
1703-
const int32_t * src1_dd = static_cast<const int32_t *>(dst->src[1]->data);
1704-
float * dst_dd = static_cast<float *>(dst->data);
1705-
1706-
dpct::queue_ptr stream = ctx.stream();
1707-
1708-
{
1709-
dpct::has_capability_or_fail(stream->get_device(),
1710-
{sycl::aspect::fp16});
1711-
1712-
stream->parallel_for(
1713-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
1714-
[=](sycl::nd_item<3> item_ct1) {
1715-
k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
1716-
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
1717-
});
1718-
}
1719-
}
1720-
17211570
static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
17221571
const int ky, const int kx_padded,
17231572
queue_ptr stream) {
@@ -1912,46 +1761,6 @@ catch (sycl::exception const &exc) {
19121761
std::exit(1);
19131762
}
19141763

1915-
static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
1916-
1917-
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32);
1918-
GGML_ASSERT(dst->type == GGML_TYPE_F32);
1919-
1920-
GGML_ASSERT(dst->src[0]->nb[0] == ggml_type_size(dst->src[0]->type));
1921-
GGML_ASSERT(dst->src[1]->nb[0] == ggml_type_size(dst->src[1]->type));
1922-
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
1923-
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->src[1]->buffer));
1924-
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));
1925-
1926-
switch (dst->src[0]->type) {
1927-
case GGML_TYPE_F16:
1928-
get_rows_sycl_float<sycl::half>(ctx, dst);
1929-
break;
1930-
case GGML_TYPE_F32:
1931-
get_rows_sycl_float<float>(ctx, dst);
1932-
break;
1933-
case GGML_TYPE_Q4_0:
1934-
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
1935-
break;
1936-
case GGML_TYPE_Q4_1:
1937-
get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst);
1938-
break;
1939-
case GGML_TYPE_Q5_0:
1940-
get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, dst);
1941-
break;
1942-
case GGML_TYPE_Q5_1:
1943-
get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, dst);
1944-
break;
1945-
case GGML_TYPE_Q8_0:
1946-
get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
1947-
break;
1948-
default:
1949-
// TODO: k-quants
1950-
GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(dst->src[0]->type));
1951-
GGML_ABORT("fatal error");
1952-
break;
1953-
}
1954-
}
19551764

19561765
inline void ggml_sycl_op_mul_mat_sycl(
19571766
ggml_backend_sycl_context & ctx,

0 commit comments

Comments
 (0)