Skip to content

Commit 4dc3ad3

Browse files
committed
clamp: move to a separate file
1 parent 4c6d4c8 commit 4dc3ad3

File tree

4 files changed

+60
-51
lines changed

4 files changed

+60
-51
lines changed

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "getrows.hpp"
3737
#include "diagmask.hpp"
3838
#include "scale.hpp"
39+
#include "clamp.hpp"
3940
#include "gla.hpp"
4041

4142
#endif // GGML_SYCL_BACKEND_HPP

ggml/src/ggml-sycl/clamp.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#include "clamp.hpp"

#include <climits>
2+
3+
static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
4+
const sycl::nd_item<3> & item_ct1) {
5+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
6+
7+
if (i >= k) {
8+
return;
9+
}
10+
11+
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
12+
}
13+
14+
static void clamp_f32_sycl(const float * x, float * dst, const float min, const float max, const int k,
15+
queue_ptr stream) {
16+
const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
17+
stream->parallel_for(
18+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
19+
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
20+
[=](sycl::nd_item<3> item_ct1) { clamp_f32(x, dst, min, max, k, item_ct1); });
21+
}
22+
23+
// Runs the clamp op on the SYCL backend: every element of dst->src[0] is clamped to
// [min, max] and written to dst.
//   ctx - backend context; the kernel is submitted to ctx.stream()
//   dst - destination tensor; dst->src[0] is the input and dst->op_params holds the two float
//         bounds (min first, then max)
// Both tensors must be F32 and dst must not live in a "_Split" buffer type. A sycl::exception
// escaping the body is printed to stderr and terminates the process with exit code 1.
inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);

    // The launcher and the kernel index elements with `int`, and the launcher rounds the count
    // up to a whole block. Reject element counts that would be truncated by the narrowing
    // conversion, or would overflow that round-up, instead of silently clamping a wrong range.
    const auto n_elements = ggml_nelements(dst->src[0]);
    GGML_ASSERT(n_elements <= INT_MAX - SYCL_CLAMP_BLOCK_SIZE);

    // op_params is raw parameter storage: copy the bounds out with memcpy rather than
    // dereferencing it through a float pointer.
    float min;
    float max;
    memcpy(&min, dst->op_params, sizeof(float));
    memcpy(&max, reinterpret_cast<const float *>(dst->op_params) + 1, sizeof(float));

    const dpct::queue_ptr main_stream = ctx.stream();
    const float *         src0_dd     = static_cast<const float *>(dst->src[0]->data);
    float *               dst_dd      = static_cast<float *>(dst->data);

    // SYCL reports errors through exceptions (handled below), so no error code is checked here.
    clamp_f32_sycl(src0_dd, dst_dd, min, max, static_cast<int>(n_elements), main_stream);
} catch (const sycl::exception & exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}
46+
47+
// Entry point for the clamp op (declared in clamp.hpp): runs ggml_sycl_op_clamp on `dst`,
// bracketed by debug traces that name this function.
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_SYCL_DEBUG("call %s\n", __func__);
    ggml_sycl_op_clamp(ctx, dst);
    GGML_SYCL_DEBUG("call %s done\n", __func__);
}

ggml/src/ggml-sycl/clamp.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef GGML_SYCL_CLAMP_HPP
2+
#define GGML_SYCL_CLAMP_HPP
3+
4+
#include "common.hpp"
5+
6+
// Clamps dst->src[0] element-wise into dst on the SYCL backend, using the two float bounds
// stored in dst->op_params. Defined in clamp.cpp.
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7+
8+
#endif // GGML_SYCL_CLAMP_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,18 +1463,6 @@ static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
14631463
}
14641464
}
14651465

1466-
static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
1467-
const sycl::nd_item<3> &item_ct1) {
1468-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1469-
item_ct1.get_local_id(2);
1470-
1471-
if (i >= k) {
1472-
return;
1473-
}
1474-
1475-
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
1476-
}
1477-
14781466
template <typename Ti, typename To>
14791467
static void pool2d_nchw_kernel(
14801468
const int ih, const int iw, const int oh, const int ow,
@@ -1600,19 +1588,6 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
16001588
}
16011589
}
16021590

1603-
static void clamp_f32_sycl(const float *x, float *dst, const float min,
1604-
const float max, const int k,
1605-
queue_ptr stream) {
1606-
const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
1607-
stream->parallel_for(
1608-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
1609-
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
1610-
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
1611-
[=](sycl::nd_item<3> item_ct1) {
1612-
clamp_f32(x, dst, min, max, k, item_ct1);
1613-
});
1614-
}
1615-
16161591
static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
16171592
const int nrows, queue_ptr stream) {
16181593
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -1905,28 +1880,6 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
19051880
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
19061881
}
19071882

1908-
// Clamp op for the SYCL backend: clamps every element of dst->src[0] into dst.
// Requires F32 input and output and a destination buffer that is not a SYCL split buffer.
// The two float bounds are read from dst->op_params (min first, then max).
inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));

    // Copy both bounds out of the raw parameter storage in one go: [0] = min, [1] = max.
    float bounds[2];
    memcpy(bounds, dst->op_params, sizeof(bounds));

    const float * input  = static_cast<const float *>(dst->src[0]->data);
    float *       output = static_cast<float *>(dst->data);
    const dpct::queue_ptr main_stream = ctx.stream();

    clamp_f32_sycl(input, output, bounds[0], bounds[1], ggml_nelements(dst->src[0]), main_stream);

    // Leftover of the DPCT migration (DPCT1010): SYCL reports errors through exceptions rather
    // than error codes, so the original error-code call was replaced with 0.
    SYCL_CHECK(0);
}
1929-
19301883
static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
19311884
static bool peer_access_enabled = false;
19321885

@@ -2848,10 +2801,6 @@ catch (sycl::exception const &exc) {
28482801
std::exit(1);
28492802
}
28502803

2851-
// File-local dispatch wrapper: forwards `dst` to ggml_sycl_op_clamp unchanged.
static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    ggml_sycl_op_clamp(ctx, dst);
}
2854-
28552804
// File-local dispatch wrapper: forwards `dst` to ggml_sycl_op_pool2d unchanged.
static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    ggml_sycl_op_pool2d(ctx, dst);
}

0 commit comments

Comments
 (0)