Skip to content

Commit 6f2b99b

Browse files
committed
Use static_cast directly
1 parent aadd688 commit 6f2b99b

File tree

2 files changed

+11
-17
lines changed

2 files changed

+11
-17
lines changed

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
2424
template<typename T>
2525
static void gelu(const T * x, T * dst, const int k,
2626
const sycl::nd_item<3> &item_ct1) {
27-
const T GELU_COEF_A = to_T<T>(0.044715f);
28-
const T SQRT_2_OVER_PI = to_T<T>(0.79788456080286535587989211986876f);
27+
const T GELU_COEF_A = static_cast<T>(0.044715f);
28+
const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
2929
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
3030
item_ct1.get_local_id(2);
3131

@@ -34,9 +34,9 @@ static void gelu(const T * x, T * dst, const int k,
3434
}
3535

3636
float xi = x[i];
37-
dst[i] = to_T<T>(0.5f) * xi *
38-
(to_T<T>(1.0f) +
39-
sycl::tanh(SQRT_2_OVER_PI * xi * (to_T<T>(1.0f) + GELU_COEF_A * xi * xi)));
37+
dst[i] = static_cast<T>(0.5f) * xi *
38+
(static_cast<T>(1.0f) +
39+
sycl::tanh(SQRT_2_OVER_PI * xi * (static_cast<T>(1.0f) + GELU_COEF_A * xi * xi)));
4040
}
4141

4242
template<typename T>
@@ -48,7 +48,7 @@ static void silu(const T * x, T * dst, const int k,
4848
if (i >= k) {
4949
return;
5050
}
51-
dst[i] = x[i] / (to_T<T>(1.0f) + sycl::native::exp(-x[i]));
51+
dst[i] = x[i] / (static_cast<T>(1.0f) + sycl::native::exp(-x[i]));
5252
}
5353

5454
template<typename T>
@@ -60,7 +60,7 @@ static void gelu_quick(const T *x, T *dst, int k,
6060
if (i >= k) {
6161
return;
6262
}
63-
dst[i] = x[i] * (to_T<T>(1.0f) / (to_T<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
63+
dst[i] = x[i] * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
6464
}
6565

6666
template<typename T>
@@ -95,7 +95,7 @@ static void sigmoid(const T * x, T * dst, const int k,
9595
if (i >= k) {
9696
return;
9797
}
98-
dst[i] = 1.0f / (to_T<T>(1.0f) + sycl::native::exp(-x[i]));
98+
dst[i] = 1.0f / (static_cast<T>(1.0f) + sycl::native::exp(-x[i]));
9999
}
100100

101101
template<typename T>
@@ -143,7 +143,7 @@ static void hardsigmoid(const T * x, T * dst, const int k,
143143
if (i >= k) {
144144
return;
145145
}
146-
dst[i] = sycl::fmin(to_T<T>(1.0f), sycl::fmax(to_T<T>(0.0f), (x[i] + to_T<T>(3.0f)) / to_T<T>(6.0f)));
146+
dst[i] = sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x[i] + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
147147
}
148148

149149
template<typename T>
@@ -155,7 +155,7 @@ static void hardswish(const T * x, T * dst, const int k,
155155
if (i >= k) {
156156
return;
157157
}
158-
dst[i] = x[i] * sycl::fmin(to_T<T>(1.0f), sycl::fmax(to_T<T>(0.0f), (x[i] + to_T<T>(3.0f)) / to_T<T>(6.0f)));
158+
dst[i] = x[i] * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x[i] + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
159159
}
160160

161161
template<typename T>
@@ -276,7 +276,7 @@ static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne
276276
item_ct1.get_group(0) * ne00 * ne01;
277277
dst[offset_dst] = x[offset_src];
278278
} else {
279-
dst[offset_dst] = to_T<T>(0.0f);
279+
dst[offset_dst] = static_cast<T>(0.0f);
280280
}
281281
}
282282

ggml/src/ggml-sycl/element_wise.hpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,6 @@ T neg_infinity() {
99
return -std::numeric_limits<T>::infinity();
1010
}
1111

12-
template <typename T>
13-
constexpr T to_T(float value) {
14-
return static_cast<T>(value);
15-
}
16-
17-
1812
static __dpct_inline__ float op_repeat(const float a, const float b) {
1913
return b;
2014
GGML_UNUSED(a);

0 commit comments

Comments
 (0)