
Commit e5b79de

diagmask: move to a separate file
1 parent d9d2e1d commit e5b79de


4 files changed: 62 additions & 55 deletions


ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@
 #include "argsort.hpp"
 #include "cpy.hpp"
 #include "getrows.hpp"
+#include "diagmask.hpp"
 #include "gla.hpp"

 #endif // GGML_SYCL_BACKEND_HPP

ggml/src/ggml-sycl/diagmask.cpp

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+#include "diagmask.hpp"
+#include <float.h>
+
+static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel,
+                              const int n_past, const sycl::nd_item<3> & item_ct1) {
+    const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1);
+    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int i = row * ncols + col;
+    //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
+    //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
+}
+
+static void diag_mask_inf_f32_sycl(const float * x, float * dst, const int ncols_x, const int nrows_x,
+                                   const int rows_per_channel, const int n_past, queue_ptr stream) {
+    const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1);
+    const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE;
+    const sycl::range<3> block_nums(1, block_num_x, nrows_x);
+    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+        diag_mask_inf_f32(x, dst, ncols_x, rows_per_channel, n_past, item_ct1);
+    });
+}
+
+inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
+
+    const int64_t ne00 = dst->src[0]->ne[0];
+    const int64_t ne01 = dst->src[0]->ne[1];
+    const int nrows0 = ggml_nrows(dst->src[0]);
+
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    dpct::queue_ptr main_stream = ctx.stream();
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float * dst_dd = static_cast<float *>(dst->data);
+
+    diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
+} catch (const sycl::exception & exc) {
+    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
+}
+
+void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_diag_mask_inf(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
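
Note on the masking line in diag_mask_inf_f32: subtracting FLT_MAX whenever col > n_past + row % rows_per_channel drives masked entries to a huge negative value, which a downstream softmax maps to ~0, matching the commented-out -INFINITY variant while avoiding a branch. A minimal host-side C++ reference of the same math (the helper name and test sizes are illustrative, not part of this commit) can be used to sanity-check the kernel:

#include <cassert>
#include <cfloat>
#include <vector>

// Host reference for diag_mask_inf_f32: masks every element whose column
// index exceeds n_past + (row % rows_per_channel), using the same
// branchless FLT_MAX subtraction as the kernel.
static void diag_mask_inf_ref(const float * x, float * dst, int ncols, int nrows,
                              int rows_per_channel, int n_past) {
    for (int row = 0; row < nrows; ++row) {
        for (int col = 0; col < ncols; ++col) {
            const int i = row * ncols + col;
            dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
        }
    }
}

int main() {
    const int ncols = 8, nrows = 4, rows_per_channel = 4, n_past = 2;
    std::vector<float> x(ncols * nrows, 1.0f), dst(ncols * nrows);
    diag_mask_inf_ref(x.data(), dst.data(), ncols, nrows, rows_per_channel, n_past);
    assert(dst[1 * ncols + 3] == 1.0f);   // row 1 may attend up to col n_past + 1 = 3
    assert(dst[1 * ncols + 4] < -1e30f);  // col 4 is past the diagonal, so masked
}

The launch geometry in diag_mask_inf_f32_sycl is a plain ceiling division: assuming SYCL_DIAG_MASK_INF_BLOCK_SIZE is 32 (it is defined in the backend's common headers, not in this diff), ncols_x = 1000 gives (1000 + 31) / 32 = 32 work-groups per row, and the col >= ncols guard retires the excess work-items in the last group.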

ggml/src/ggml-sycl/diagmask.hpp

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+#ifndef GGML_SYCL_DIAG_MASK
+#define GGML_SYCL_DIAG_MASK
+
+#include "common.hpp"
+
+void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif // GGML_SYCL_DIAG_MASK

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 0 additions & 55 deletions
@@ -1463,24 +1463,6 @@ static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
     }
 }

-
-static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
-                              const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-
-    if (col >= ncols) {
-        return;
-    }
-
-    const int i = row*ncols + col;
-    //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
-    //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
-    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
-}
-
 static void scale_f32(const float * x, float * dst, const float scale, const int k,
                       const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -1666,21 +1648,6 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
     });
 }

-static void diag_mask_inf_f32_sycl(const float *x, float *dst,
-                                   const int ncols_x, const int nrows_x,
-                                   const int rows_per_channel, const int n_past,
-                                   queue_ptr stream) {
-    const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1);
-    const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE;
-    const sycl::range<3> block_nums(1, block_num_x, nrows_x);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             diag_mask_inf_f32(x, dst, ncols_x,
-                                               rows_per_channel, n_past,
-                                               item_ct1);
-                         });
-}
-
 static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
                                           const struct ggml_tensor *src,
                                           int64_t i3, int64_t i2,
@@ -1962,24 +1929,6 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
     sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
 }

-inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));
-
-    const int64_t ne00 = dst->src[0]->ne[0];
-    const int64_t ne01 = dst->src[0]->ne[1];
-    const int nrows0 = ggml_nrows(dst->src[0]);
-
-    const int n_past = ((int32_t *) dst->op_params)[0];
-    dpct::queue_ptr main_stream = ctx.stream();
-    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
-    float * dst_dd = static_cast<float *>(dst->data);
-
-    diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
-}
-
 inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {

     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
@@ -2957,10 +2906,6 @@ static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_cpy(ctx, dst->src[0], dst);
 }

-static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_op_diag_mask_inf(ctx, dst);
-}
-
 static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
     ggml_sycl_op_rope(ctx, dst);
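
With both the static wrapper and the op implementation removed from ggml-sycl.cpp, the remaining call site resolves to the ggml_sycl_diag_mask_inf declared in diagmask.hpp (pulled in via backend.hpp). A sketch of that dispatch, paraphrased from the backend's op switch rather than shown in this diff:

// In the backend's op dispatch (paraphrased; not part of this diff):
switch (dst->op) {
    // ...
    case GGML_OP_DIAG_MASK_INF:
        ggml_sycl_diag_mask_inf(ctx, dst);  // now the function from diagmask.cpp
        break;
    // ...
}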
