Skip to content

Commit 8b1894b

Browse files
hidekisaito authored and vladimirlaz committed
[SYCL] Improve SYCL vector implementation
- Improve convert(), load(), store() methods. - Handle special case for char vectors on host. - Improve comments. - Add long long support. Signed-off-by: Hideki Saito <[email protected]> Signed-off-by: Vladimir Lazarev <[email protected]>
1 parent e721955 commit 8b1894b

File tree

1 file changed

+156
-30
lines changed

1 file changed

+156
-30
lines changed

sycl/include/CL/sycl/types.hpp

Lines changed: 156 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,46 @@
77
//
88
//===----------------------------------------------------------------------===//
99

10+
// Implements vec and __swizzled_vec__ classes.
11+
1012
#pragma once
1113

14+
// Define __NO_EXT_VECTOR_TYPE_ON_HOST__ to avoid using ext_vector_type
15+
// extension even if the host compiler supports it. The same can be
16+
// accomplished by -D__NO_EXT_VECTOR_TYPE_ON_HOST__ command line option.
17+
#ifndef __NO_EXT_VECTOR_TYPE_ON_HOST__
18+
// #define __NO_EXT_VECTOR_TYPE_ON_HOST__
19+
#endif
20+
21+
// Check if Clang's ext_vector_type attribute is available. Host compiler
22+
// may not be Clang, and Clang may not be built with the extension.
23+
#ifdef __clang__
24+
#ifndef __has_extension
25+
#define __has_extension(x) 0
26+
#endif
27+
#ifdef __HAS_EXT_VECTOR_TYPE__
28+
#error "Undefine __HAS_EXT_VECTOR_TYPE__ macro"
29+
#endif
30+
#if __has_extension(attribute_ext_vector_type)
31+
#define __HAS_EXT_VECTOR_TYPE__
32+
#endif
33+
#endif // __clang__
34+
35+
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
36+
#error "Undefine __SYCL_USE_EXT_VECTOR_TYPE__ macro"
37+
#endif
38+
#ifdef __HAS_EXT_VECTOR_TYPE__
39+
#if defined(__SYCL_DEVICE_ONLY__) || !defined(__NO_EXT_VECTOR_TYPE_ON_HOST__)
40+
#define __SYCL_USE_EXT_VECTOR_TYPE__
41+
#endif
42+
#elif defined(__SYCL_DEVICE_ONLY__)
43+
// We expect the device compiler to have ext_vector_type support. Ideally this
44+
// would only be a soft error, but as written a missing extension is fatal.
45+
#error "SYCL device compiler is built without ext_vector_type support"
46+
#endif // __HAS_EXT_VECTOR_TYPE__
47+
1248
#include <CL/sycl/detail/common.hpp>
1349

14-
#ifndef __SYCL_DEVICE_ONLY__
15-
#include <algorithm>
16-
#include <functional>
17-
#endif // __SYCL_DEVICE_ONLY__
1850
// 4.10.1: Scalar data types
1951
// 4.10.2: SYCL vector types
2052

@@ -183,6 +215,19 @@ template <typename T> struct LShift {
183215
}
184216
};
185217

218+
template <typename T, typename convertT, rounding_mode roundingMode>
219+
T convertHelper(const T &Opnd) {
220+
if (roundingMode == rounding_mode::automatic ||
221+
roundingMode == rounding_mode::rtz) {
222+
return static_cast<convertT>(Opnd);
223+
}
224+
if (roundingMode == rounding_mode::rtp) {
225+
return static_cast<convertT>(ceil(Opnd));
226+
}
227+
// roundingMode == rounding_mode::rtn
228+
return static_cast<convertT>(floor(Opnd));
229+
}
230+
186231
} // namespace detail
187232

188233
template <typename DataT, int NumElements> class vec {
@@ -338,7 +383,7 @@ template <typename DataT, int NumElements> class vec {
338383
void dump() {
339384
#ifndef __SYCL_DEVICE_ONLY__
340385
for (int I = 0; I < NumElements; ++I) {
341-
std::cout << " " << I << ": " << m_Data.s[I] << std::endl;
386+
std::cout << " " << I << ": " << getValue(I) << std::endl;
342387
}
343388
std::cout << std::endl;
344389
#endif // __SYCL_DEVICE_ONLY__
@@ -361,12 +406,20 @@ template <typename DataT, int NumElements> class vec {
361406
size_t get_count() const { return NumElements; }
362407
size_t get_size() const { return sizeof(m_Data); }
363408

364-
// TODO: convert() for FP types. Also, check whether rounding mode handling
409+
// TODO: convert() for FP to FP. Also, check whether rounding mode handling
365410
// is needed for integers to FP convert.
366-
// template <typename convertT, rounding_mode roundingMode>
367-
// vec<convertT, NumElements> convert() const;
411+
//
412+
// Convert to same type is no-op.
413+
template <typename convertT, rounding_mode roundingMode>
414+
typename std::enable_if<std::is_same<DataT, convertT>::value,
415+
vec<convertT, NumElements>>::type
416+
convert() const {
417+
return *this;
418+
}
419+
// From Integer to Integer or FP
368420
template <typename convertT, rounding_mode roundingMode>
369-
typename std::enable_if<std::is_integral<DataT>::value,
421+
typename std::enable_if<!std::is_same<DataT, convertT>::value &&
422+
std::is_integral<DataT>::value,
370423
vec<convertT, NumElements>>::type
371424
convert() const {
372425
vec<convertT, NumElements> Result;
@@ -375,6 +428,20 @@ template <typename DataT, int NumElements> class vec {
375428
}
376429
return Result;
377430
}
431+
// From FP to Integer
432+
template <typename convertT, rounding_mode roundingMode>
433+
typename std::enable_if<!std::is_same<DataT, convertT>::value &&
434+
std::is_integral<convertT>::value &&
435+
std::is_floating_point<DataT>::value,
436+
vec<convertT, NumElements>>::type
437+
convert() const {
438+
vec<convertT, NumElements> Result;
439+
for (size_t I = 0; I < NumElements; ++I) {
440+
Result.setValue(
441+
I, detail::convertHelper<convertT, roundingMode>(getValue(I)));
442+
}
443+
return Result;
444+
}
378445

379446
template <typename asT>
380447
typename std::enable_if<sizeof(asT) == sizeof(DataType), asT>::type
@@ -415,12 +482,24 @@ template <typename DataT, int NumElements> class vec {
415482
#endif
416483
#define __SYCL_LOADSTORE(Space) \
417484
void load(size_t Offset, multi_ptr<DataT, Space> Ptr) { \
418-
m_Data = *multi_ptr<DataType, Space>(static_cast<DataType *>( \
419-
static_cast<void *>(Ptr + Offset * NumElements))); \
485+
if (NumElements != 3) { \
486+
m_Data = *multi_ptr<DataType, Space>(static_cast<DataType *>( \
487+
static_cast<void *>(Ptr + Offset * NumElements))); \
488+
return; \
489+
} \
490+
for (int I = 0; I < NumElements; I++) { \
491+
setValue(I, *multi_ptr<DataT, Space>(Ptr + Offset * NumElements + I)); \
492+
} \
420493
} \
421494
void store(size_t Offset, multi_ptr<DataT, Space> Ptr) const { \
422-
*multi_ptr<DataType, Space>(static_cast<DataType *>( \
423-
static_cast<void *>(Ptr + Offset * NumElements))) = m_Data; \
495+
if (NumElements != 3) { \
496+
*multi_ptr<DataType, Space>(static_cast<DataType *>( \
497+
static_cast<void *>(Ptr + Offset * NumElements))) = m_Data; \
498+
return; \
499+
} \
500+
for (int I = 0; I < NumElements; I++) { \
501+
*multi_ptr<DataT, Space>(Ptr + Offset * NumElements + I) = getValue(I); \
502+
} \
424503
}
425504

426505
__SYCL_LOADSTORE(access::address_space::global_space)
@@ -433,7 +512,7 @@ template <typename DataT, int NumElements> class vec {
433512
#error "Undefine __SYCL_BINOP macro"
434513
#endif
435514

436-
#ifdef __SYCL_DEVICE_ONLY__
515+
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
437516
#define __SYCL_BINOP(BINOP, OPASSIGN) \
438517
vec operator BINOP(const vec &Rhs) const { \
439518
vec Ret; \
@@ -457,7 +536,7 @@ template <typename DataT, int NumElements> class vec {
457536
*this = *this BINOP vec(Rhs); \
458537
return *this; \
459538
}
460-
#else // __SYCL_DEVICE_ONLY__
539+
#else // __SYCL_USE_EXT_VECTOR_TYPE__
461540
#define __SYCL_BINOP(BINOP, OPASSIGN) \
462541
vec operator BINOP(const vec &Rhs) const { \
463542
vec Ret; \
@@ -483,7 +562,7 @@ template <typename DataT, int NumElements> class vec {
483562
*this = *this BINOP vec(Rhs); \
484563
return *this; \
485564
}
486-
#endif // __SYCL_DEVICE_ONLY__
565+
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
487566

488567
__SYCL_BINOP(+, +=)
489568
__SYCL_BINOP(-, -=)
@@ -588,21 +667,21 @@ template <typename DataT, int NumElements> class vec {
588667
vec<DataT, NumElements>
589668
operatorHelper(const vec<DataT, NumElements> &Rhs) const {
590669
vec<DataT, NumElements> Result;
591-
#ifdef __SYCL_DEVICE_ONLY__
670+
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
592671
Operation<DataType> Op;
593672
Result.m_Data = Op(m_Data, Rhs.m_Data);
594-
#else // __SYCL_DEVICE_ONLY__
673+
#else // __SYCL_USE_EXT_VECTOR_TYPE__
595674
Operation<DataT> Op;
596675
for (size_t I = 0; I < NumElements; ++I) {
597676
Result.setValue(I, Op(Rhs.getValue(I), getValue(I)));
598677
}
599-
#endif // __SYCL_DEVICE_ONLY__
678+
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
600679
return Result;
601680
}
602681

603682
// setValue and getValue should be able to operate on different underlying
604683
// types: enum cl_float#N , builtin vector float#N, builtin type float.
605-
#ifdef __SYCL_DEVICE_ONLY__
684+
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
606685
template <int Num = NumElements,
607686
typename = typename std::enable_if<1 != Num>::type>
608687
void setValue(int Index, const DataT &Value, int) {
@@ -614,7 +693,7 @@ template <typename DataT, int NumElements> class vec {
614693
DataT getValue(int Index, int) const {
615694
return m_Data[Index];
616695
}
617-
#else
696+
#else // __SYCL_USE_EXT_VECTOR_TYPE__
618697
template <int Num = NumElements,
619698
typename = typename std::enable_if<1 != Num>::type>
620699
void setValue(int Index, const DataT &Value, int) {
@@ -626,7 +705,7 @@ template <typename DataT, int NumElements> class vec {
626705
DataT getValue(int Index, int) const {
627706
return m_Data.s[Index];
628707
}
629-
#endif
708+
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
630709

631710
template <int Num = NumElements,
632711
typename = typename std::enable_if<1 == Num>::type>
@@ -1275,7 +1354,7 @@ class SwizzleOp {
12751354
template <typename T, int Num> \
12761355
typename std::enable_if<std::is_fundamental<T>::value, vec<T, Num>>::type \
12771356
operator BINOP(const T &Lhs, const vec<T, Num> &Rhs) { \
1278-
return vec<T, Num>(static_cast<T>(Lhs)) BINOP Rhs; \
1357+
return vec<T, Num>(Lhs) BINOP Rhs; \
12791358
} \
12801359
template <typename VecT, typename OperationLeftT, typename OperationRightT, \
12811360
template <typename> class OperationCurrentT, int... Indexes, \
@@ -1367,7 +1446,7 @@ __SYCL_RELLOGOP(||)
13671446
} // namespace sycl
13681447
} // namespace cl
13691448

1370-
#ifdef __SYCL_DEVICE_ONLY__
1449+
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
13711450
typedef char __char_t;
13721451
typedef char __char2_vec_t __attribute__((ext_vector_type(2)));
13731452
typedef char __char3_vec_t __attribute__((ext_vector_type(3)));
@@ -1461,7 +1540,7 @@ typedef double __double16_vec_t __attribute__((ext_vector_type(16)));
14611540

14621541
#define GET_CL_TYPE(target, num) __##target##num##_vec_t
14631542
#define GET_SCALAR_CL_TYPE(target) target
1464-
#else // __SYCL_DEVICE_ONLY__
1543+
#else // __SYCL_USE_EXT_VECTOR_TYPE__
14651544
// For signed char. OpenCL doesn't have any type about `signed char`, therefore
14661545
// we use type alias of cl_char instead.
14671546
using cl_schar = cl_char;
@@ -1473,7 +1552,7 @@ using cl_schar16 = cl_char16;
14731552

14741553
#define GET_CL_TYPE(target, num) cl_##target##num
14751554
#define GET_SCALAR_CL_TYPE(target) cl_##target
1476-
#endif // __SYCL_DEVICE_ONLY__
1555+
#endif // __SYCL_USE_EXT_VECTOR_TYPE__
14771556

14781557
namespace cl {
14791558
namespace sycl {
@@ -1484,6 +1563,12 @@ namespace sycl {
14841563
using DataType = GET_CL_TYPE(base, num); \
14851564
};
14861565

1566+
// Specializes BaseCLTypeConverter for the element type spelled `base##long`
// (the tokens are pasted, so base=long yields the identifier `longlong` and
// base=ulong yields `ulonglong`) and vector width `num`.
// DataType is taken from the global-scope GET_CL_TYPE(base, num), i.e. the
// same underlying vector type that is used for `base` itself.
// NOTE(review): assumes `longlong` / `ulonglong` are type names declared
// elsewhere in this header - confirm.
#define DECLARE_LONGLONG_CONVERTER(base, num) \
1567+
template <> class BaseCLTypeConverter<base##long, num> { \
1568+
public: \
1569+
using DataType = ::GET_CL_TYPE(base, num); \
1570+
};
1571+
14871572
#define DECLARE_VECTOR_CONVERTERS(base) \
14881573
namespace detail { \
14891574
DECLARE_CONVERTER(base, 2) \
@@ -1497,6 +1582,19 @@ namespace sycl {
14971582
}; \
14981583
} // namespace detail
14991584

1585+
#define DECLARE_VECTOR_LONGLONG_CONVERTERS(base) \
1586+
namespace detail { \
1587+
DECLARE_LONGLONG_CONVERTER(base, 2) \
1588+
DECLARE_LONGLONG_CONVERTER(base, 3) \
1589+
DECLARE_LONGLONG_CONVERTER(base, 4) \
1590+
DECLARE_LONGLONG_CONVERTER(base, 8) \
1591+
DECLARE_LONGLONG_CONVERTER(base, 16) \
1592+
template <> class BaseCLTypeConverter<base##long, 1> { \
1593+
public: \
1594+
using DataType = GET_SCALAR_CL_TYPE(base); \
1595+
}; \
1596+
} // namespace detail
1597+
15001598
#define DECLARE_SYCL_VEC_WO_CONVERTERS(base) \
15011599
using cl_##base##16 = vec<base, 16>; \
15021600
using cl_##base##8 = vec<base, 8>; \
@@ -1510,11 +1608,40 @@ namespace sycl {
15101608
using base##3 = cl_##base##3; \
15111609
using base##2 = cl_##base##2;
15121610

1611+
// char-specific counterpart of DECLARE_SYCL_VEC_WO_CONVERTERS.
// The cl_-prefixed aliases are built on `signed char` (cl_char itself is
// `signed char`), whereas the unprefixed charN aliases are built on plain
// `char`. Since `char` and `signed char` are distinct types in C++, cl_charN
// and charN name different vec instantiations here, unlike the generic macro
// where base##N is an alias of cl_##base##N.
#define DECLARE_SYCL_VEC_CHAR_WO_CONVERTERS \
1612+
using cl_char16 = vec<signed char, 16>; \
1613+
using cl_char8 = vec<signed char, 8>; \
1614+
using cl_char4 = vec<signed char, 4>; \
1615+
using cl_char3 = vec<signed char, 3>; \
1616+
using cl_char2 = vec<signed char, 2>; \
1617+
using cl_char = signed char; \
1618+
using char16 = vec<char, 16>; \
1619+
using char8 = vec<char, 8>; \
1620+
using char4 = vec<char, 4>; \
1621+
using char3 = vec<char, 3>; \
1622+
using char2 = vec<char, 2>;
1623+
1624+
// cl_longlong/cl_ulonglong are not supported in SYCL
1625+
// Declares the vector aliases `base##longN` for N = 16, 8, 4, 3 and 2, each
// naming vec<base##long, N>. With base=long this produces longlong16 ...
// longlong2; with base=ulong it produces ulonglong16 ... ulonglong2.
// Only the unprefixed aliases are emitted; no cl_-prefixed names are declared.
#define DECLARE_SYCL_VEC_LONGLONG_WO_CONVERTERS(base) \
1626+
using base##long16 = vec<base##long, 16>; \
1627+
using base##long8 = vec<base##long, 8>; \
1628+
using base##long4 = vec<base##long, 4>; \
1629+
using base##long3 = vec<base##long, 3>; \
1630+
using base##long2 = vec<base##long, 2>;
1631+
15131632
#define DECLARE_SYCL_VEC(base) \
15141633
DECLARE_VECTOR_CONVERTERS(base) \
15151634
DECLARE_SYCL_VEC_WO_CONVERTERS(base)
15161635

1517-
DECLARE_SYCL_VEC(char)
1636+
#define DECLARE_SYCL_VEC_CHAR \
1637+
DECLARE_VECTOR_CONVERTERS(char) \
1638+
DECLARE_SYCL_VEC_CHAR_WO_CONVERTERS
1639+
1640+
#define DECLARE_SYCL_VEC_LONGLONG(base) \
1641+
DECLARE_VECTOR_LONGLONG_CONVERTERS(base) \
1642+
DECLARE_SYCL_VEC_LONGLONG_WO_CONVERTERS(base)
1643+
1644+
DECLARE_SYCL_VEC_CHAR
15181645
DECLARE_SYCL_VEC(schar)
15191646
DECLARE_SYCL_VEC(uchar)
15201647
DECLARE_SYCL_VEC(short)
@@ -1523,9 +1650,8 @@ DECLARE_SYCL_VEC(int)
15231650
DECLARE_SYCL_VEC(uint)
15241651
DECLARE_SYCL_VEC(long)
15251652
DECLARE_SYCL_VEC(ulong)
1526-
// TODO: Fix long long and unsigned long long.
1527-
// DECLARE_SYCL_VEC(longlong)
1528-
// DECLARE_SYCL_VEC(ulonglong)
1653+
DECLARE_SYCL_VEC_LONGLONG(long)
1654+
DECLARE_SYCL_VEC_LONGLONG(ulong)
15291655
DECLARE_SYCL_VEC(float)
15301656
DECLARE_SYCL_VEC(double)
15311657
// TODO: Fix half.

0 commit comments

Comments
 (0)