Skip to content

Commit 096d0a0

Browse files
authored
[SYCL] Added support of rounding modes for floating and integer types (#1576)
Signed-off-by: Aleksander Fadeev <[email protected]>
1 parent 875347a commit 096d0a0

File tree

7 files changed

+515
-163
lines changed

7 files changed

+515
-163
lines changed

sycl/include/CL/sycl/types.hpp

Lines changed: 188 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,32 @@ using is_int_to_int =
199199
std::integral_constant<bool, std::is_integral<T>::value &&
200200
std::is_integral<R>::value>;
201201

202+
template <typename T, typename R>
203+
using is_sint_to_sint =
204+
std::integral_constant<bool, is_sigeninteger<T>::value &&
205+
is_sigeninteger<R>::value>;
206+
207+
template <typename T, typename R>
208+
using is_uint_to_uint =
209+
std::integral_constant<bool, is_sugeninteger<T>::value &&
210+
is_sugeninteger<R>::value>;
211+
212+
template <typename T, typename R>
213+
using is_sint_to_from_uint = std::integral_constant<
214+
bool, is_sugeninteger<T>::value && is_sigeninteger<R>::value ||
215+
is_sigeninteger<T>::value && is_sugeninteger<R>::value>;
216+
217+
template <typename T, typename R>
218+
using is_sint_to_float =
219+
std::integral_constant<bool, std::is_integral<T>::value &&
220+
!(std::is_unsigned<T>::value) &&
221+
detail::is_floating_point<R>::value>;
222+
223+
template <typename T, typename R>
224+
using is_uint_to_float =
225+
std::integral_constant<bool, std::is_unsigned<T>::value &&
226+
detail::is_floating_point<R>::value>;
227+
202228
template <typename T, typename R>
203229
using is_int_to_float =
204230
std::integral_constant<bool, std::is_integral<T>::value &&
@@ -213,15 +239,23 @@ template <typename T, typename R>
213239
using is_float_to_float =
214240
std::integral_constant<bool, detail::is_floating_point<T>::value &&
215241
detail::is_floating_point<R>::value>;
242+
template <typename T>
243+
using is_standard_type = std::integral_constant<
244+
bool, detail::is_sgentype<T>::value && !std::is_same<T, long long>::value &&
245+
!std::is_same<T, unsigned long long>::value>;
216246

217-
template <typename T, typename R, rounding_mode roundingMode>
247+
template <typename T, typename R, rounding_mode roundingMode, typename OpenCLT,
248+
typename OpenCLR>
218249
detail::enable_if_t<std::is_same<T, R>::value, R> convertImpl(T Value) {
219250
return Value;
220251
}
221252

253+
#ifndef __SYCL_DEVICE_ONLY__
254+
222255
// Note for float to half conversions, static_cast calls the conversion operator
223256
// implemented for host that takes care of the precision requirements.
224-
template <typename T, typename R, rounding_mode roundingMode>
257+
template <typename T, typename R, rounding_mode roundingMode, typename OpenCLT,
258+
typename OpenCLR>
225259
detail::enable_if_t<!std::is_same<T, R>::value &&
226260
(is_int_to_int<T, R>::value ||
227261
is_int_to_float<T, R>::value ||
@@ -231,9 +265,9 @@ convertImpl(T Value) {
231265
return static_cast<R>(Value);
232266
}
233267

234-
#ifndef __SYCL_DEVICE_ONLY__
235268
// float to int
236-
template <typename T, typename R, rounding_mode roundingMode>
269+
template <typename T, typename R, rounding_mode roundingMode, typename OpenCLT,
270+
typename OpenCLR>
237271
detail::enable_if_t<is_float_to_int<T, R>::value, R> convertImpl(T Value) {
238272
switch (roundingMode) {
239273
// Round to nearest even is default rounding mode for floating-point types
@@ -280,16 +314,145 @@ using Rtp = detail::bool_constant<Mode == rounding_mode::rtp>;
280314
template <rounding_mode Mode>
281315
using Rtn = detail::bool_constant<Mode == rounding_mode::rtn>;
282316

283-
// Convert floating-point type to integer type
317+
// convert signed and unsigned types with an equal size and diff names
318+
template <typename T, typename R, rounding_mode roundingMode, typename OpenCLT,
319+
typename OpenCLR>
320+
detail::enable_if_t<!std::is_same<T, R>::value &&
321+
(is_sint_to_sint<T, R>::value ||
322+
is_uint_to_uint<T, R>::value) &&
323+
std::is_same<OpenCLT, OpenCLR>::value,
324+
R>
325+
convertImpl(T Value) {
326+
return static_cast<R>(Value);
327+
}
328+
329+
// signed to signed
330+
#define __SYCL_GENERATE_CONVERT_IMPL(DestType) \
331+
template <typename T, typename R, rounding_mode roundingMode, \
332+
typename OpenCLT, typename OpenCLR> \
333+
detail::enable_if_t<!std::is_same<T, R>::value && \
334+
is_sint_to_sint<T, R>::value && \
335+
(std::is_same<OpenCLR, DestType>::value || \
336+
std::is_same<OpenCLR, signed char>::value && \
337+
std::is_same<DestType, char>::value) && \
338+
!std::is_same<OpenCLT, OpenCLR>::value, \
339+
R> \
340+
convertImpl(T Value) { \
341+
OpenCLT OpValue = cl::sycl::detail::convertDataToType<T, OpenCLT>(Value); \
342+
return __spirv_SConvert##_R##DestType(OpValue); \
343+
}
344+
345+
__SYCL_GENERATE_CONVERT_IMPL(char)
346+
__SYCL_GENERATE_CONVERT_IMPL(short)
347+
__SYCL_GENERATE_CONVERT_IMPL(int)
348+
__SYCL_GENERATE_CONVERT_IMPL(long)
349+
__SYCL_GENERATE_CONVERT_IMPL(longlong)
350+
351+
#undef __SYCL_GENERATE_CONVERT_IMPL
352+
353+
// unsigned to unsigned
354+
#define __SYCL_GENERATE_CONVERT_IMPL(DestType) \
355+
template <typename T, typename R, rounding_mode roundingMode, \
356+
typename OpenCLT, typename OpenCLR> \
357+
detail::enable_if_t<!std::is_same<T, R>::value && \
358+
is_uint_to_uint<T, R>::value && \
359+
std::is_same<OpenCLR, DestType>::value && \
360+
!std::is_same<OpenCLT, OpenCLR>::value, \
361+
R> \
362+
convertImpl(T Value) { \
363+
OpenCLT OpValue = cl::sycl::detail::convertDataToType<T, OpenCLT>(Value); \
364+
return __spirv_UConvert##_R##DestType(OpValue); \
365+
}
366+
367+
__SYCL_GENERATE_CONVERT_IMPL(uchar)
368+
__SYCL_GENERATE_CONVERT_IMPL(ushort)
369+
__SYCL_GENERATE_CONVERT_IMPL(uint)
370+
__SYCL_GENERATE_CONVERT_IMPL(ulong)
371+
372+
#undef __SYCL_GENERATE_CONVERT_IMPL
373+
374+
// unsigned to (from) signed
375+
template <typename T, typename R, rounding_mode roundingMode, typename OpenCLT,
376+
typename OpenCLR>
377+
detail::enable_if_t<is_sint_to_from_uint<T, R>::value, R> convertImpl(T Value) {
378+
return static_cast<R>(Value);
379+
}
380+
381+
// sint to float
382+
#define __SYCL_GENERATE_CONVERT_IMPL(SPIRVOp, DestType) \
383+
template <typename T, typename R, rounding_mode roundingMode, \
384+
typename OpenCLT, typename OpenCLR> \
385+
detail::enable_if_t< \
386+
is_sint_to_float<T, R>::value && std::is_same<R, DestType>::value, R> \
387+
convertImpl(T Value) { \
388+
OpenCLT OpValue = cl::sycl::detail::convertDataToType<T, OpenCLT>(Value); \
389+
return __spirv_Convert##SPIRVOp##_R##DestType(OpValue); \
390+
}
391+
392+
__SYCL_GENERATE_CONVERT_IMPL(SToF, half)
393+
__SYCL_GENERATE_CONVERT_IMPL(SToF, float)
394+
__SYCL_GENERATE_CONVERT_IMPL(SToF, double)
395+
396+
#undef __SYCL_GENERATE_CONVERT_IMPL
397+
398+
// uint to float
399+
#define __SYCL_GENERATE_CONVERT_IMPL(SPIRVOp, DestType) \
400+
template <typename T, typename R, rounding_mode roundingMode, \
401+
typename OpenCLT, typename OpenCLR> \
402+
detail::enable_if_t< \
403+
is_uint_to_float<T, R>::value && std::is_same<R, DestType>::value, R> \
404+
convertImpl(T Value) { \
405+
OpenCLT OpValue = cl::sycl::detail::convertDataToType<T, OpenCLT>(Value); \
406+
return __spirv_Convert##SPIRVOp##_R##DestType(OpValue); \
407+
}
408+
409+
__SYCL_GENERATE_CONVERT_IMPL(UToF, half)
410+
__SYCL_GENERATE_CONVERT_IMPL(UToF, float)
411+
__SYCL_GENERATE_CONVERT_IMPL(UToF, double)
412+
413+
#undef __SYCL_GENERATE_CONVERT_IMPL
414+
415+
// float to float
416+
#define __SYCL_GENERATE_CONVERT_IMPL(DestType, RoundingMode, \
417+
RoundingModeCondition) \
418+
template <typename T, typename R, rounding_mode roundingMode, \
419+
typename OpenCLT, typename OpenCLR> \
420+
detail::enable_if_t<!std::is_same<T, R>::value && \
421+
is_float_to_float<T, R>::value && \
422+
std::is_same<R, DestType>::value && \
423+
RoundingModeCondition<roundingMode>::value, \
424+
R> \
425+
convertImpl(T Value) { \
426+
OpenCLT OpValue = cl::sycl::detail::convertDataToType<T, OpenCLT>(Value); \
427+
return __spirv_FConvert##_R##DestType##_##RoundingMode(OpValue); \
428+
}
429+
430+
#define __SYCL_GENERATE_CONVERT_IMPL_FOR_ROUNDING_MODE(RoundingMode, \
431+
RoundingModeCondition) \
432+
__SYCL_GENERATE_CONVERT_IMPL(double, RoundingMode, RoundingModeCondition) \
433+
__SYCL_GENERATE_CONVERT_IMPL(float, RoundingMode, RoundingModeCondition) \
434+
__SYCL_GENERATE_CONVERT_IMPL(half, RoundingMode, RoundingModeCondition)
435+
436+
__SYCL_GENERATE_CONVERT_IMPL_FOR_ROUNDING_MODE(rte, RteOrAutomatic)
437+
__SYCL_GENERATE_CONVERT_IMPL_FOR_ROUNDING_MODE(rtz, Rtz)
438+
__SYCL_GENERATE_CONVERT_IMPL_FOR_ROUNDING_MODE(rtp, Rtp)
439+
__SYCL_GENERATE_CONVERT_IMPL_FOR_ROUNDING_MODE(rtn, Rtn)
440+
441+
#undef __SYCL_GENERATE_CONVERT_IMPL_FOR_ROUNDING_MODE
442+
#undef __SYCL_GENERATE_CONVERT_IMPL
443+
444+
// float to int
284445
#define __SYCL_GENERATE_CONVERT_IMPL(SPIRVOp, DestType, RoundingMode, \
285446
RoundingModeCondition) \
286-
template <typename T, typename R, rounding_mode roundingMode> \
447+
template <typename T, typename R, rounding_mode roundingMode, \
448+
typename OpenCLT, typename OpenCLR> \
287449
detail::enable_if_t<is_float_to_int<T, R>::value && \
288-
std::is_same<R, DestType>::value && \
450+
(std::is_same<OpenCLR, DestType>::value || \
451+
std::is_same<OpenCLR, signed char>::value && \
452+
std::is_same<DestType, char>::value) && \
289453
RoundingModeCondition<roundingMode>::value, \
290454
R> \
291455
convertImpl(T Value) { \
292-
using OpenCLT = cl::sycl::detail::ConvertToOpenCLType_t<T>; \
293456
OpenCLT OpValue = cl::sycl::detail::convertDataToType<T, OpenCLT>(Value); \
294457
return __spirv_Convert##SPIRVOp##_R##DestType##_##RoundingMode(OpValue); \
295458
}
@@ -319,6 +482,18 @@ __SYCL_GENERATE_CONVERT_IMPL_FOR_ROUNDING_MODE(rtn, Rtn)
319482
#undef __SYCL_GENERATE_CONVERT_IMPL_FOR_ROUNDING_MODE
320483
#undef __SYCL_GENERATE_CONVERT_IMPL
321484

485+
// Back up
486+
template <typename T, typename R, rounding_mode roundingMode, typename OpenCLT,
487+
typename OpenCLR>
488+
detail::enable_if_t<
489+
(!is_standard_type<T>::value && !is_standard_type<OpenCLT>::value ||
490+
!is_standard_type<R>::value && !is_standard_type<OpenCLR>::value) &&
491+
!std::is_same<OpenCLT, OpenCLR>::value,
492+
R>
493+
convertImpl(T Value) {
494+
return static_cast<R>(Value);
495+
}
496+
322497
#endif // __SYCL_DEVICE_ONLY__
323498

324499
} // namespace detail
@@ -627,9 +802,13 @@ template <typename Type, int NumElements> class vec {
627802
detail::is_floating_point<convertT>::value,
628803
"Unsupported convertT");
629804
vec<convertT, NumElements> Result;
805+
using OpenCLT = detail::ConvertToOpenCLType_t<DataT>;
806+
using OpenCLR = detail::ConvertToOpenCLType_t<convertT>;
630807
for (size_t I = 0; I < NumElements; ++I) {
631808
Result.setValue(
632-
I, detail::convertImpl<DataT, convertT, roundingMode>(getValue(I)));
809+
I,
810+
detail::convertImpl<DataT, convertT, roundingMode, OpenCLT, OpenCLR>(
811+
getValue(I)));
633812
}
634813
return Result;
635814
}

0 commit comments

Comments
 (0)