Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit b428b17

Browse files
[ESIMD] Enable math functions: cos,sin,exp,log
1 parent 0d6e673 commit b428b17

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

SYCL/ESIMD/ext_math.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include <CL/sycl.hpp>
1818
#include <CL/sycl/INTEL/esimd.hpp>
19+
#include <CL/sycl/builtins_esimd.hpp>
1920
#include <iostream>
2021

2122
using namespace cl::sycl;
@@ -35,7 +36,16 @@ struct InitDataFuncWide {
3536
struct InitDataFuncNarrow {
3637
void operator()(float *In, float *Out, size_t Size) const {
3738
for (auto I = 0; I < Size; ++I) {
38-
In[I] = 2.0f + 16.0f * ((float)I / (float)(Size - 1)); // in [2..16] range
39+
In[I] = 2.0f + 16.0f * ((float)I / (float)(Size - 1)); // in [2..18] range
40+
Out[I] = (float)0.0;
41+
}
42+
}
43+
};
44+
45+
struct InitDataInRange0_5 {
46+
void operator()(float *In, float *Out, size_t Size) const {
47+
for (auto I = 0; I < Size; ++I) {
48+
In[I] = 5.0f * ((float)I / (float)(Size - 1)); // in [0..5] range
3949
Out[I] = (float)0.0;
4050
}
4151
}
@@ -61,10 +71,19 @@ template <MathOp Op> float HostMathFunc(float X);
6171
} \
6272
}
6373

64-
DEFINE_OP(sin, sin);
65-
DEFINE_OP(cos, cos);
66-
DEFINE_OP(exp, exp);
67-
DEFINE_OP(log, log);
74+
#define DEFINE_OP_REUSE_SCALAR(Op, HostOp) \
75+
template <> float HostMathFunc<MathOp::Op>(float X) { return HostOp(X); } \
76+
template <int VL> struct DeviceMathFunc<VL, MathOp::Op> { \
77+
simd<float, VL> \
78+
operator()(const simd<float, VL> &X) const SYCL_ESIMD_FUNCTION { \
79+
return sycl::Op<VL>(X); \
80+
} \
81+
}
82+
83+
DEFINE_OP_REUSE_SCALAR(sin, sin);
84+
DEFINE_OP_REUSE_SCALAR(cos, cos);
85+
DEFINE_OP_REUSE_SCALAR(exp, exp);
86+
DEFINE_OP_REUSE_SCALAR(log, log);
6887
DEFINE_OP(inv, 1.0f /);
6988
DEFINE_OP(sqrt, sqrt);
7089
DEFINE_OP(rsqrt, 1.0f / sqrt);
@@ -159,13 +178,10 @@ template <int VL> bool test(queue &Q) {
159178
Pass &= test<MathOp::sqrt, VL>(Q, "sqrt", InitDataFuncWide{});
160179
Pass &= test<MathOp::inv, VL>(Q, "inv");
161180
Pass &= test<MathOp::rsqrt, VL>(Q, "rsqrt");
162-
// TODO enable these tests after the implementation is fixed
163-
#if ENABLE_SIN_COS_EXP_LOG
164181
Pass &= test<MathOp::sin, VL>(Q, "sin", InitDataFuncWide{});
165182
Pass &= test<MathOp::cos, VL>(Q, "cos", InitDataFuncWide{});
166-
Pass &= test<MathOp::exp, VL>(Q, "exp");
183+
Pass &= test<MathOp::exp, VL>(Q, "exp", InitDataInRange0_5{});
167184
Pass &= test<MathOp::log, VL>(Q, "log", InitDataFuncWide{});
168-
#endif
169185
return Pass;
170186
}
171187

0 commit comments

Comments
 (0)