Skip to content

Commit 818682e

Browse files
authored
[SYCL][ESIMD]Refactor several math functions and remove duplicate constants (#7490)
Complementary test PR: intel/llvm-test-suite#1405
1 parent c1947ff commit 818682e

File tree

1 file changed

+62
-153
lines changed
  • sycl/include/sycl/ext/intel/experimental/esimd

1 file changed

+62
-153
lines changed

sycl/include/sycl/ext/intel/experimental/esimd/math.hpp

Lines changed: 62 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,19 +1152,24 @@ sincos(__ESIMD_NS::simd<float, SZ> &dstcos, U src0, Sat sat = {}) {
11521152

11531153
/// @cond ESIMD_DETAIL
11541154
namespace detail {
1155-
constexpr double HDR_CONST_PI = 3.1415926535897932384626433832795;
1155+
constexpr double __ESIMD_CONST_PI = 3.1415926535897932384626433832795;
11561156
} // namespace detail
11571157
/// @endcond ESIMD_DETAIL
11581158

11591159
template <typename T, int SZ>
1160-
ESIMD_NODEBUG ESIMD_INLINE
1161-
std::enable_if_t<std::is_floating_point<T>::value, __ESIMD_NS::simd<T, SZ>>
1162-
atan(__ESIMD_NS::simd<T, SZ> src0) {
1160+
ESIMD_NODEBUG ESIMD_INLINE __ESIMD_NS::simd<T, SZ>
1161+
atan(__ESIMD_NS::simd<T, SZ> src0) {
1162+
static_assert(std::is_floating_point<T>::value,
1163+
"Floating point argument type is expected.");
11631164
__ESIMD_NS::simd<T, SZ> Src0 = __ESIMD_NS::abs(src0);
11641165

1165-
__ESIMD_NS::simd_mask<SZ> Neg = src0 < T(0.0);
1166+
__ESIMD_NS::simd<T, SZ> OneP((T)1.0);
1167+
__ESIMD_NS::simd<T, SZ> OneN((T)-1.0);
1168+
__ESIMD_NS::simd<T, SZ> sign;
11661169
__ESIMD_NS::simd_mask<SZ> Gt1 = Src0 > T(1.0);
11671170

1171+
sign.merge(OneN, OneP, src0 < 0);
1172+
11681173
Src0.merge(__ESIMD_NS::inv(Src0), Gt1);
11691174

11701175
__ESIMD_NS::simd<T, SZ> Src0P2 = Src0 * Src0;
@@ -1179,13 +1184,14 @@ ESIMD_NODEBUG ESIMD_INLINE
11791184
((Src0 * T(0.395889) + T(1.12158)) * Src0P2) + (Src0 * T(0.636918)) +
11801185
T(1.0));
11811186

1182-
Result.merge(Result - T(detail::HDR_CONST_PI / 2.0), Gt1);
1183-
Result.merge(Result, Neg);
1184-
return Result;
1187+
Result.merge(Result - T(detail::__ESIMD_CONST_PI) / T(2.0), Gt1);
1188+
1189+
return __ESIMD_NS::abs(Result) * sign;
11851190
}
11861191

1187-
template <typename T>
1188-
__ESIMD_API std::enable_if_t<std::is_floating_point<T>::value, T> atan(T src0) {
1192+
template <typename T> __ESIMD_API T atan(T src0) {
1193+
static_assert(std::is_floating_point<T>::value,
1194+
"Floating point argument type is expected.");
11891195
__ESIMD_NS::simd<T, 1> Src0 = src0;
11901196
__ESIMD_NS::simd<T, 1> Result = esimd::atan(Src0);
11911197
return Result[0];
@@ -1218,7 +1224,7 @@ ESIMD_NODEBUG ESIMD_INLINE
12181224
__ESIMD_NS::rsqrt(Src01m * T(2.0));
12191225

12201226
Result.merge(T(0.0), TooBig);
1221-
Result.merge(T(detail::HDR_CONST_PI) - Result, Neg);
1227+
Result.merge(T(detail::__ESIMD_CONST_PI) - Result, Neg);
12221228
return Result;
12231229
}
12241230

@@ -1238,7 +1244,7 @@ ESIMD_NODEBUG ESIMD_INLINE
12381244
__ESIMD_NS::simd_mask<SZ> Neg = src0 < T(0.0);
12391245

12401246
__ESIMD_NS::simd<T, SZ> Result =
1241-
T(detail::HDR_CONST_PI / 2.0) - esimd::acos(__ESIMD_NS::abs(src0));
1247+
T(detail::__ESIMD_CONST_PI / 2.0) - esimd::acos(__ESIMD_NS::abs(src0));
12421248

12431249
Result.merge(-Result, Neg);
12441250
return Result;
@@ -1309,41 +1315,28 @@ template <int N> __ESIMD_NS::simd<float, N> tanh(__ESIMD_NS::simd<float, N> x);
13091315
/* ------------------------- Extended Math Routines
13101316
* -------------------------------------------------*/
13111317

1312-
/// @cond ESIMD_DETAIL
1313-
1314-
namespace detail {
1315-
static auto constexpr CONST_PI = 3.14159f;
1316-
static auto constexpr CMPI = 3.14159265f;
1317-
} // namespace detail
1318-
1319-
/// @endcond ESIMD_DETAIL
1320-
13211318
// For vector input
13221319
template <int N>
13231320
ESIMD_INLINE __ESIMD_NS::simd<float, N>
13241321
atan2_fast(__ESIMD_NS::simd<float, N> y, __ESIMD_NS::simd<float, N> x) {
1325-
__ESIMD_NS::simd<float, N> a0;
1326-
__ESIMD_NS::simd<float, N> a1;
1322+
/* smallest such that 1.0+CONST_DBL_EPSILON != 1.0 */
1323+
constexpr float CONST_DBL_EPSILON = 0.00001f;
1324+
__ESIMD_NS::simd<float, N> OneP(1.0f);
1325+
__ESIMD_NS::simd<float, N> OneN(-1.0f);
1326+
__ESIMD_NS::simd<float, N> sign;
13271327
__ESIMD_NS::simd<float, N> atan2;
1328+
__ESIMD_NS::simd<float, N> r;
1329+
__ESIMD_NS::simd_mask<N> mask = x < 0;
1330+
__ESIMD_NS::simd<float, N> abs_y = __ESIMD_NS::abs(y) + CONST_DBL_EPSILON;
13281331

1329-
__ESIMD_NS::simd_mask<N> mask = (y >= 0.0f);
1330-
a0.merge(detail::CONST_PI * 0.5f, detail::CONST_PI * 1.5f, mask);
1331-
a1.merge(0, detail::CONST_PI * 2.0f, mask);
1332-
1333-
a1.merge(detail::CONST_PI, x < 0.0f);
1334-
1335-
__ESIMD_NS::simd<float, N> xy = x * y;
1336-
__ESIMD_NS::simd<float, N> x2 = x * x;
1337-
__ESIMD_NS::simd<float, N> y2 = y * y;
1338-
1339-
/* smallest such that 1.0+CONST_DBL_EPSILON != 1.0 */
1340-
constexpr auto CONST_DBL_EPSILON = 0.00001f;
1332+
r.merge((x + abs_y) / (abs_y - x), (x - abs_y) / (x + abs_y), mask);
1333+
atan2.merge(float(detail::__ESIMD_CONST_PI) * 0.75f,
1334+
float(detail::__ESIMD_CONST_PI) * 0.25f, mask);
1335+
atan2 += (0.1963f * r * r - 0.9817f) * r;
13411336

1342-
a0 -= (xy / (y2 + x2 * 0.28f + CONST_DBL_EPSILON));
1343-
a1 += (xy / (x2 + y2 * 0.28f + CONST_DBL_EPSILON));
1337+
sign.merge(OneN, OneP, y < 0);
13441338

1345-
atan2.merge(a1, a0, y2 <= x2);
1346-
return atan2;
1339+
return atan2 * sign;
13471340
}
13481341

13491342
// For Scalar Input
@@ -1360,30 +1353,30 @@ template <int N>
13601353
ESIMD_INLINE __ESIMD_NS::simd<float, N> atan2(__ESIMD_NS::simd<float, N> y,
13611354
__ESIMD_NS::simd<float, N> x) {
13621355
__ESIMD_NS::simd<float, N> v_distance;
1363-
__ESIMD_NS::simd<float, N> v_y0;
13641356
__ESIMD_NS::simd<float, N> atan2;
13651357
__ESIMD_NS::simd_mask<N> mask;
13661358

1367-
mask = (x < 0);
1368-
v_y0.merge(detail::CONST_PI, 0, mask);
1359+
constexpr float CONST_DBL_EPSILON = 0.00001f;
1360+
1361+
mask = (x < -CONST_DBL_EPSILON && y < CONST_DBL_EPSILON && y >= 0.f);
1362+
atan2.merge(float(detail::__ESIMD_CONST_PI), 0.f, mask);
1363+
mask = (x < -CONST_DBL_EPSILON && y > -CONST_DBL_EPSILON && y < 0);
1364+
atan2.merge(float(-detail::__ESIMD_CONST_PI), mask);
1365+
mask = (x < CONST_DBL_EPSILON && __ESIMD_NS::abs(y) > CONST_DBL_EPSILON);
13691366
v_distance = __ESIMD_NS::sqrt(x * x + y * y);
1370-
mask = (__ESIMD_NS::abs<float>(y) < 0.000001f);
1371-
atan2.merge(v_y0, (2 * esimd::atan((v_distance - x) / y)), mask);
1367+
atan2.merge(2.0f * esimd::atan((v_distance - x) / y), mask);
1368+
1369+
mask = (x > 0.f);
1370+
atan2.merge(2.0f * esimd::atan(y / (v_distance + x)), mask);
1371+
13721372
return atan2;
13731373
}
13741374

13751375
// For Scalar Input
13761376
template <> ESIMD_INLINE float atan2(float y, float x) {
1377-
float v_distance;
1378-
float v_y0;
1379-
__ESIMD_NS::simd<float, 1> atan2;
1380-
__ESIMD_NS::simd_mask<1> mask;
1381-
1382-
mask = (x < 0);
1383-
v_y0 = mask[0] ? detail::CONST_PI : 0;
1384-
v_distance = __ESIMD_NS::sqrt<float>(x * x + y * y);
1385-
mask = (__ESIMD_NS::abs<float>(y) < 0.000001f);
1386-
atan2.merge(v_y0, (2 * esimd::atan((v_distance - x) / y)), mask);
1377+
__ESIMD_NS::simd<float, 1> vy = y;
1378+
__ESIMD_NS::simd<float, 1> vx = x;
1379+
__ESIMD_NS::simd<float, 1> atan2 = esimd::atan2(vy, vx);
13871380
return atan2[0];
13881381
}
13891382

@@ -1394,6 +1387,7 @@ ESIMD_INLINE __ESIMD_NS::simd<float, N> fmod(__ESIMD_NS::simd<float, N> y,
13941387
__ESIMD_NS::simd<float, N> x) {
13951388
__ESIMD_NS::simd<float, N> abs_x = __ESIMD_NS::abs(x);
13961389
__ESIMD_NS::simd<float, N> abs_y = __ESIMD_NS::abs(y);
1390+
13971391
auto fmod_sign_mask = (y.template bit_cast_view<int32_t>()) & 0x80000000;
13981392

13991393
__ESIMD_NS::simd<float, N> reminder =
@@ -1423,18 +1417,19 @@ ESIMD_INLINE __ESIMD_NS::simd<float, N> sin_emu(__ESIMD_NS::simd<float, N> x) {
14231417

14241418
__ESIMD_NS::simd<float, N> sign;
14251419
__ESIMD_NS::simd<float, N> fTrig;
1426-
__ESIMD_NS::simd<float, N> TwoPI(6.2831853f);
1427-
__ESIMD_NS::simd<float, N> CmpI(detail::CMPI);
1428-
__ESIMD_NS::simd<float, N> OneP(1.f);
1429-
__ESIMD_NS::simd<float, N> OneN(-1.f);
1420+
__ESIMD_NS::simd<float, N> TwoPI(float(detail::__ESIMD_CONST_PI) * 2.0f);
1421+
__ESIMD_NS::simd<float, N> CmpI((float)detail::__ESIMD_CONST_PI);
1422+
__ESIMD_NS::simd<float, N> OneP(1.0f);
1423+
__ESIMD_NS::simd<float, N> OneN(-1.0f);
14301424

14311425
x = esimd::fmod(x, TwoPI);
1426+
x.merge(TwoPI + x, x < 0);
14321427

1433-
x1.merge(CmpI - x, x - CmpI, (x <= detail::CMPI));
1434-
x1.merge(x, (x <= detail::CMPI * 0.5f));
1435-
x1.merge(CmpI * 2 - x, (x > detail::CMPI * 1.5f));
1428+
x1.merge(CmpI - x, x - CmpI, (x <= float(detail::__ESIMD_CONST_PI)));
1429+
x1.merge(x, (x <= float(detail::__ESIMD_CONST_PI) * 0.5f));
1430+
x1.merge(TwoPI - x, (x > float(detail::__ESIMD_CONST_PI) * 1.5f));
14361431

1437-
sign.merge(OneN, OneP, (x > detail::CMPI));
1432+
sign.merge(OneN, OneP, (x > float(detail::__ESIMD_CONST_PI)));
14381433

14391434
x2 = x1 * x1;
14401435
t3 = x2 * x1 * 0.1666667f;
@@ -1449,106 +1444,20 @@ ESIMD_INLINE __ESIMD_NS::simd<float, N> sin_emu(__ESIMD_NS::simd<float, N> x) {
14491444
}
14501445

14511446
// scalar Input
1452-
template <typename T> ESIMD_INLINE float sin_emu(T x0) {
1453-
__ESIMD_NS::simd<float, 1> x1;
1454-
__ESIMD_NS::simd<float, 1> x2;
1455-
__ESIMD_NS::simd<float, 1> t3;
1456-
1457-
__ESIMD_NS::simd<float, 1> sign;
1458-
__ESIMD_NS::simd<float, 1> fTrig;
1459-
float TwoPI = detail::CMPI * 2.0f;
1460-
1461-
__ESIMD_NS::simd<float, 1> x = esimd::fmod(x0, TwoPI);
1462-
1463-
__ESIMD_NS::simd<float, 1> CmpI(detail::CMPI);
1464-
__ESIMD_NS::simd<float, 1> OneP(1.f);
1465-
__ESIMD_NS::simd<float, 1> OneN(-1.f);
1466-
1467-
x1.merge(CmpI - x, x - CmpI, (x <= detail::CMPI));
1468-
x1.merge(x, (x <= detail::CMPI * 0.5f));
1469-
x1.merge(CmpI * 2.0f - x, (x > detail::CMPI * 1.5f));
1470-
1471-
sign.merge(OneN, OneP, (x > detail::CMPI));
1472-
1473-
x2 = x1 * x1;
1474-
t3 = x2 * x1 * 0.1666667f;
1475-
1476-
fTrig =
1477-
x1 + t3 * (OneN + x2 * 0.05f *
1478-
(OneP + x2 * 0.0238095f *
1479-
(OneN + x2 * 0.0138889f *
1480-
(OneP - x2 * 0.0090909f))));
1481-
fTrig *= sign;
1482-
return fTrig[0];
1447+
template <> ESIMD_INLINE float sin_emu(float x0) {
1448+
return esimd::sin_emu(__ESIMD_NS::simd<float, 1>(x0))[0];
14831449
}
14841450

14851451
// cos_emu - EU emulation for sin(x)
14861452
// For Vector input
14871453
template <int N>
14881454
ESIMD_INLINE __ESIMD_NS::simd<float, N> cos_emu(__ESIMD_NS::simd<float, N> x) {
1489-
__ESIMD_NS::simd<float, N> x1;
1490-
__ESIMD_NS::simd<float, N> x2;
1491-
__ESIMD_NS::simd<float, N> t2;
1492-
__ESIMD_NS::simd<float, N> t3;
1493-
1494-
__ESIMD_NS::simd<float, N> sign;
1495-
__ESIMD_NS::simd<float, N> fTrig;
1496-
__ESIMD_NS::simd<float, N> TwoPI(6.2831853f);
1497-
__ESIMD_NS::simd<float, N> CmpI(detail::CMPI);
1498-
__ESIMD_NS::simd<float, N> OneP(1.f);
1499-
__ESIMD_NS::simd<float, N> OneN(-1.f);
1500-
1501-
x = esimd::fmod(x, TwoPI);
1502-
1503-
x1.merge(x - detail::CMPI * 0.5f, CmpI * 1.5f - x, (x <= detail::CMPI));
1504-
x1.merge(CmpI * 0.5f - x, (x <= detail::CMPI * 0.5f));
1505-
x1.merge(x - detail::CMPI * 1.5f, (x > detail::CMPI * 1.5f));
1506-
1507-
sign.merge(1, -1, ((x < detail::CMPI * 0.5f) | (x >= detail::CMPI * 1.5f)));
1508-
1509-
x2 = x1 * x1;
1510-
t3 = x2 * x1 * 0.1666667f;
1511-
fTrig =
1512-
x1 + t3 * (OneN + x2 * 0.05f *
1513-
(OneP + x2 * 0.0238095f *
1514-
(OneN + x2 * 0.0138889f *
1515-
(OneP - x2 * 0.0090909f))));
1516-
fTrig *= sign;
1517-
return fTrig;
1455+
return esimd::sin_emu(0.5f * float(detail::__ESIMD_CONST_PI) - x);
15181456
}
15191457

15201458
// scalar Input
1521-
template <typename T> ESIMD_INLINE float cos_emu(T x0) {
1522-
__ESIMD_NS::simd<float, 1> x1;
1523-
__ESIMD_NS::simd<float, 1> x2;
1524-
__ESIMD_NS::simd<float, 1> t3;
1525-
1526-
__ESIMD_NS::simd<float, 1> sign;
1527-
__ESIMD_NS::simd<float, 1> fTrig;
1528-
float TwoPI = detail::CMPI * 2.0f;
1529-
1530-
__ESIMD_NS::simd<float, 1> x = esimd::fmod(x0, TwoPI);
1531-
1532-
__ESIMD_NS::simd<float, 1> CmpI(detail::CMPI);
1533-
__ESIMD_NS::simd<float, 1> OneP(1.f);
1534-
__ESIMD_NS::simd<float, 1> OneN(-1.f);
1535-
1536-
x1.merge(x - detail::CMPI * 0.5f, CmpI * 1.5f - x, (x <= detail::CMPI));
1537-
x1.merge(CmpI * 0.5f - x, (x <= detail::CMPI * 0.5f));
1538-
x1.merge(x - detail::CMPI * 1.5f, (x > detail::CMPI * 1.5f));
1539-
1540-
sign.merge(OneP, OneN,
1541-
((x < detail::CMPI * 0.5f) | (x >= detail::CMPI * 1.5f)));
1542-
1543-
x2 = x1 * x1;
1544-
t3 = x2 * x1 * 0.1666667f;
1545-
fTrig =
1546-
x1 + t3 * (OneN + x2 * 0.05f *
1547-
(OneP + x2 * 0.0238095f *
1548-
(OneN + x2 * 0.0138889f *
1549-
(OneP - x2 * 0.0090909f))));
1550-
fTrig *= sign;
1551-
return fTrig[0];
1459+
template <> ESIMD_INLINE float cos_emu(float x0) {
1460+
return esimd::cos_emu(__ESIMD_NS::simd<float, 1>(x0))[0];
15521461
}
15531462

15541463
/// @cond ESIMD_DETAIL

0 commit comments

Comments
 (0)