Skip to content

Commit 5a31917

Browse files
committed
scale: move to a separate file
1 parent e139b9c commit 5a31917

File tree

4 files changed

+57
-49
lines changed

4 files changed

+57
-49
lines changed

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "cpy.hpp"
3636
#include "getrows.hpp"
3737
#include "diagmask.hpp"
38+
#include "scale.hpp"
3839
#include "gla.hpp"
3940

4041
#endif // GGML_SYCL_BACKEND_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,18 +1463,6 @@ static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
14631463
}
14641464
}
14651465

1466-
static void scale_f32(const float * x, float * dst, const float scale, const int k,
1467-
const sycl::nd_item<3> &item_ct1) {
1468-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1469-
item_ct1.get_local_id(2);
1470-
1471-
if (i >= k) {
1472-
return;
1473-
}
1474-
1475-
dst[i] = scale * x[i];
1476-
}
1477-
14781466
static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
14791467
const sycl::nd_item<3> &item_ct1) {
14801468
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -1612,18 +1600,6 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
16121600
}
16131601
}
16141602

1615-
static void scale_f32_sycl(const float *x, float *dst, const float scale,
1616-
const int k, queue_ptr stream) {
1617-
const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
1618-
stream->parallel_for(
1619-
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
1620-
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
1621-
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
1622-
[=](sycl::nd_item<3> item_ct1) {
1623-
scale_f32(x, dst, scale, k, item_ct1);
1624-
});
1625-
}
1626-
16271603
static void clamp_f32_sycl(const float *x, float *dst, const float min,
16281604
const float max, const int k,
16291605
queue_ptr stream) {
@@ -1929,27 +1905,6 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
19291905
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
19301906
}
19311907

1932-
inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1933-
1934-
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1935-
GGML_ASSERT(dst->type == GGML_TYPE_F32);
1936-
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));
1937-
1938-
float scale;
1939-
memcpy(&scale, dst->op_params, sizeof(float));
1940-
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
1941-
float * dst_dd = static_cast<float *>(dst->data);
1942-
1943-
dpct::queue_ptr main_stream = ctx.stream();
1944-
1945-
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
1946-
/*
1947-
DPCT1010:87: SYCL uses exceptions to report errors and does not use the
1948-
error codes. The call was replaced with 0. You need to rewrite this code.
1949-
*/
1950-
SYCL_CHECK(0);
1951-
}
1952-
19531908
inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
19541909

19551910
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
@@ -2893,10 +2848,6 @@ catch (sycl::exception const &exc) {
28932848
std::exit(1);
28942849
}
28952850

2896-
static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2897-
ggml_sycl_op_scale(ctx, dst);
2898-
}
2899-
29002851
static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
29012852
ggml_sycl_op_clamp(ctx, dst);
29022853
}

ggml/src/ggml-sycl/scale.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#include "scale.hpp"
2+
3+
static void scale_f32(const float * x, float * dst, const float scale, const int k, const sycl::nd_item<3> & item_ct1) {
4+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
5+
6+
if (i >= k) {
7+
return;
8+
}
9+
10+
dst[i] = scale * x[i];
11+
}
12+
13+
static void scale_f32_sycl(const float * x, float * dst, const float scale, const int k, queue_ptr stream) {
14+
const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
15+
stream->parallel_for(
16+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
17+
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
18+
[=](sycl::nd_item<3> item_ct1) { scale_f32(x, dst, scale, k, item_ct1); });
19+
}
20+
21+
inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
22+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
23+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
24+
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
25+
26+
float scale;
27+
memcpy(&scale, dst->op_params, sizeof(float));
28+
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
29+
float * dst_dd = static_cast<float *>(dst->data);
30+
31+
dpct::queue_ptr main_stream = ctx.stream();
32+
33+
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
34+
/*
35+
DPCT1010:87: SYCL uses exceptions to report errors and does not use the
36+
error codes. The call was replaced with 0. You need to rewrite this code.
37+
*/
38+
SYCL_CHECK(0);
39+
} catch (const sycl::exception & exc) {
40+
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
41+
std::exit(1);
42+
}
43+
44+
void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
45+
GGML_SYCL_DEBUG("call %s\n", __func__);
46+
ggml_sycl_op_scale(ctx, dst);
47+
GGML_SYCL_DEBUG("call %s done\n", __func__);
48+
}

ggml/src/ggml-sycl/scale.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef GGML_SYCL_SCALE_HPP
2+
#define GGML_SYCL_SCALE_HPP
3+
4+
#include "common.hpp"
5+
6+
void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7+
8+
#endif // GGML_SYCL_SCALE_HPP

0 commit comments

Comments
 (0)