Skip to content

Commit e78c51d

Browse files
[SYCL] Fix vec::convert method.
Fixed non-compiling code. Added support for rounding modes on the host device. Signed-off-by: Alexey Voronov <[email protected]>
1 parent 0e44dd2 commit e78c51d

File tree

2 files changed

+215
-55
lines changed

2 files changed

+215
-55
lines changed

sycl/include/CL/sycl/types.hpp

Lines changed: 76 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,12 @@
4545
#endif // __HAS_EXT_VECTOR_TYPE__
4646

4747
#include <CL/sycl/detail/common.hpp>
48+
#include <CL/sycl/detail/type_traits.hpp>
4849
#include <CL/sycl/half_type.hpp>
4950
#include <CL/sycl/multi_ptr.hpp>
5051

5152
#include <array>
53+
#include <cmath>
5254

5355
// 4.10.1: Scalar data types
5456
// 4.10.2: SYCL vector types
@@ -226,17 +228,75 @@ template <typename T> struct LShift {
226228
}
227229
};
228230

229-
template <typename T, typename convertT, rounding_mode roundingMode>
230-
T convertHelper(const T &Opnd) {
231-
if (roundingMode == rounding_mode::automatic ||
232-
roundingMode == rounding_mode::rtz) {
233-
return static_cast<convertT>(Opnd);
234-
}
235-
if (roundingMode == rounding_mode::rtp) {
236-
return static_cast<convertT>(ceil(Opnd));
237-
}
238-
// roundingMode == rounding_mode::rtn
239-
return static_cast<convertT>(floor(Opnd));
231+
template <typename T>
232+
using is_floating_point =
233+
std::integral_constant<bool, std::is_floating_point<T>::value ||
234+
std::is_same<T, half>::value>;
235+
236+
template <typename T, typename R>
237+
using is_int_to_int =
238+
std::integral_constant<bool, std::is_integral<T>::value &&
239+
std::is_integral<R>::value>;
240+
241+
template <typename T, typename R>
242+
using is_int_to_float =
243+
std::integral_constant<bool, std::is_integral<T>::value &&
244+
detail::is_floating_point<R>::value>;
245+
246+
template <typename T, typename R>
247+
using is_float_to_int =
248+
std::integral_constant<bool, detail::is_floating_point<T>::value &&
249+
std::is_integral<R>::value>;
250+
251+
template <typename T, typename R>
252+
using is_float_to_float =
253+
std::integral_constant<bool, detail::is_floating_point<T>::value &&
254+
detail::is_floating_point<R>::value>;
255+
256+
template <typename T, typename R, rounding_mode roundingMode>
257+
detail::enable_if_t<std::is_same<T, R>::value, R> convertImpl(T Value) {
258+
return Value;
259+
}
260+
261+
template <typename T, typename R, rounding_mode roundingMode>
262+
detail::enable_if_t<!std::is_same<T, R>::value &&
263+
(is_int_to_int<T, R>::value ||
264+
is_int_to_float<T, R>::value ||
265+
is_float_to_float<T, R>::value),
266+
R>
267+
convertImpl(T Value) {
268+
return static_cast<R>(Value);
269+
}
270+
271+
// float to int
272+
template <typename T, typename R, rounding_mode roundingMode>
273+
detail::enable_if_t<!std::is_same<T, R>::value && is_float_to_int<T, R>::value,
274+
R>
275+
convertImpl(T Value) {
276+
#ifndef __SYCL_DEVICE_ONLY__
277+
switch (roundingMode) {
278+
// Round to nearest even is default rounding mode for floating-point types
279+
case rounding_mode::automatic:
280+
// Round to nearest even.
281+
case rounding_mode::rte:
282+
return std::round(Value);
283+
// Round toward zero.
284+
case rounding_mode::rtz:
285+
return std::trunc(Value);
286+
// Round toward positive infinity.
287+
case rounding_mode::rtp:
288+
return std::ceil(Value);
289+
// Round toward negative infinity.
290+
case rounding_mode::rtn:
291+
return std::floor(Value);
292+
default:
293+
assert(!"Unsupported rounding mode!");
294+
return static_cast<R>(Value);
295+
};
296+
#else
297+
// TODO implement device side convertion.
298+
return static_cast<R>(Value);
299+
#endif
240300
}
241301

242302
} // namespace detail
@@ -513,56 +573,17 @@ template <typename Type, int NumElements> class vec {
513573
static constexpr size_t get_count() { return NumElements; }
514574
static constexpr size_t get_size() { return sizeof(m_Data); }
515575

516-
// TODO: convert() for FP to FP. Also, check whether rounding mode handling
517-
// is needed for integers to FP convert.
518-
//
519-
// Convert to same type is no-op.
520-
template <typename convertT, rounding_mode roundingMode>
521-
typename std::enable_if<std::is_same<DataT, convertT>::value,
522-
vec<convertT, NumElements>>::type
523-
convert() const {
524-
return *this;
525-
}
526-
// From Integer to Integer or FP
527-
template <typename convertT, rounding_mode roundingMode>
528-
typename std::enable_if<!std::is_same<DataT, convertT>::value &&
529-
std::is_integral<DataT>::value,
530-
vec<convertT, NumElements>>::type
531-
convert() const {
532-
// Use __SYCL_DEVICE_ONLY__ macro because cast to OpenCL vector type is defined
533-
// by SYCL device compiler only.
534-
#ifdef __SYCL_DEVICE_ONLY__
535-
return vec<convertT, NumElements>{
536-
(typename vec<convertT, NumElements>::DataType)m_Data};
537-
#else
538-
vec<convertT, NumElements> Result;
539-
for (size_t I = 0; I < NumElements; ++I) {
540-
Result.setValue(I, static_cast<convertT>(getValue(I)));
541-
}
542-
return Result;
543-
#endif
544-
}
545-
// From FP to Integer
546576
template <typename convertT, rounding_mode roundingMode>
547-
typename std::enable_if<!std::is_same<DataT, convertT>::value &&
548-
std::is_integral<convertT>::value &&
549-
std::is_floating_point<DataT>::value,
550-
vec<convertT, NumElements>>::type
551-
convert() const {
552-
// Use __SYCL_DEVICE_ONLY__ macro because cast to OpenCL vector type is defined
553-
// by SYCL device compiler only.
554-
#ifdef __SYCL_DEVICE_ONLY__
555-
return vec<convertT, NumElements>{
556-
detail::convertHelper<vec<convertT, NumElements>::DataType,
557-
roundingMode>(m_Data)};
558-
#else
577+
vec<convertT, NumElements> convert() const {
578+
static_assert(std::is_integral<convertT>::value ||
579+
detail::is_floating_point<convertT>::value,
580+
"Unsupported convertT");
559581
vec<convertT, NumElements> Result;
560582
for (size_t I = 0; I < NumElements; ++I) {
561583
Result.setValue(
562-
I, detail::convertHelper<convertT, roundingMode>(getValue(I)));
584+
I, detail::convertImpl<DataT, convertT, roundingMode>(getValue(I)));
563585
}
564586
return Result;
565-
#endif
566587
}
567588

568589
template <typename asT>

sycl/test/basic_tests/vec_convert.cpp

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out -lOpenCL
2+
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
3+
// RUNx: %CPU_RUN_PLACEHOLDER %t.out
4+
// RUNx: %GPU_RUN_PLACEHOLDER %t.out
5+
// RUNx: %ACC_RUN_PLACEHOLDER %t.out
6+
//==------------ vec_convert.cpp - SYCL vec class convert method test ------==//
7+
//
8+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9+
// See https://llvm.org/LICENSE.txt for license information.
10+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include <CL/sycl.hpp>
15+
16+
#include <cassert>
17+
18+
// TODO uncomment run lines on non-host devices when the rounding modes will
19+
// be implemented.
20+
21+
using namespace cl::sycl;
22+
23+
template <typename T, typename convertT, int roundingMode> class kernel_name;
24+
25+
template <int N> struct helper;
26+
27+
template <> struct helper<0> {
28+
template <typename T, int NumElements>
29+
static void compare(const vec<T, NumElements> &x,
30+
const vec<T, NumElements> &y) {
31+
const T xs = x.template swizzle<0>();
32+
const T ys = y.template swizzle<0>();
33+
assert(xs == ys);
34+
}
35+
};
36+
37+
template <int N> struct helper {
38+
template <typename T, int NumElements>
39+
static void compare(const vec<T, NumElements> &x,
40+
const vec<T, NumElements> &y) {
41+
const T xs = x.template swizzle<N>();
42+
const T ys = y.template swizzle<N>();
43+
helper<N - 1>::compare(x, y);
44+
assert(xs == ys);
45+
}
46+
};
47+
48+
template <typename T, typename convertT, int NumElements,
49+
rounding_mode roundingMode>
50+
void test(const vec<T, NumElements> &ToConvert,
51+
const vec<convertT, NumElements> &Expected) {
52+
vec<convertT, NumElements> Converted{0};
53+
{
54+
buffer<vec<convertT, NumElements>, 1> Buffer{&Converted, range<1>{1}};
55+
queue Queue;
56+
Queue.submit([&](handler &CGH) {
57+
accessor<vec<convertT, NumElements>, 1, access::mode::write> Accessor(
58+
Buffer, CGH);
59+
CGH.single_task<class kernel_name<T, convertT, static_cast<int>(roundingMode)>>([=]() {
60+
Accessor[0] = ToConvert.template convert<convertT, roundingMode>();
61+
});
62+
});
63+
}
64+
helper<NumElements - 1>::compare(Converted, Expected);
65+
}
66+
67+
int main() {
68+
// automatic
69+
test<int, int, 8, rounding_mode::automatic>(
70+
int8{2, 3, 3, -2, -3, -3, 0, 0},
71+
int8{2, 3, 3, -2, -3, -3, 0, 0});
72+
test<float, int, 8, rounding_mode::automatic>(
73+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
74+
int8{2, 3, 3, -2, -3, -3, 0, 0});
75+
test<int, float, 8, rounding_mode::automatic>(
76+
int8{2, 3, 3, -2, -3, -3, 0, 0},
77+
float8{2.f, 3.f, 3.f, -2.f, -3.f, -3.f, 0.f, 0.f});
78+
test<float, float, 8, rounding_mode::automatic>(
79+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
80+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
81+
82+
// rte
83+
test<int, int, 8, rounding_mode::rte>(
84+
int8{2, 3, 3, -2, -3, -3, 0, 0},
85+
int8{2, 3, 3, -2, -3, -3, 0, 0});
86+
test<float, int, 8, rounding_mode::rte>(
87+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
88+
int8{2, 3, 3, -2, -3, -3, 0, 0});
89+
test<int, float, 8, rounding_mode::rte>(
90+
int8{2, 3, 3, -2, -3, -3, 0, 0},
91+
float8{2.f, 3.f, 3.f, -2.f, -3.f, -3.f, 0.f, 0.f});
92+
test<float, float, 8, rounding_mode::rte>(
93+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
94+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
95+
96+
// rtz
97+
test<int, int, 8, rounding_mode::rtz>(
98+
int8{2, 3, 3, -2, -3, -3, 0, 0},
99+
int8{2, 3, 3, -2, -3, -3, 0, 0});
100+
test<float, int, 8, rounding_mode::rtz>(
101+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
102+
int8{2, 2, 2, -2, -2, -2, 0, 0});
103+
test<int, float, 8, rounding_mode::rtz>(
104+
int8{2, 3, 3, -2, -3, -3, 0, 0},
105+
float8{2.f, 3.f, 3.f, -2.f, -3.f, -3.f, 0.f, 0.f});
106+
test<float, float, 8, rounding_mode::rtz>(
107+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
108+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
109+
110+
// rtp
111+
test<int, int, 8, rounding_mode::rtp>(
112+
int8{2, 3, 3, -2, -3, -3, 0, 0},
113+
int8{2, 3, 3, -2, -3, -3, 0, 0});
114+
test<float, int, 8, rounding_mode::rtp>(
115+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
116+
int8{3, 3, 3, -2, -2, -2, 0, 0});
117+
test<int, float, 8, rounding_mode::rtp>(
118+
int8{2, 3, 3, -2, -3, -3, 0, 0},
119+
float8{2.f, 3.f, 3.f, -2.f, -3.f, -3.f, 0.f, 0.f});
120+
test<float, float, 8, rounding_mode::rtp>(
121+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
122+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
123+
124+
// rtn
125+
test<int, int, 8, rounding_mode::rtn>(
126+
int8{2, 3, 3, -2, -3, -3, 0, 0},
127+
int8{2, 3, 3, -2, -3, -3, 0, 0});
128+
test<float, int, 8, rounding_mode::rtn>(
129+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
130+
int8{2, 2, 2, -3, -3, -3, 0, 0});
131+
test<int, float, 8, rounding_mode::rtn>(
132+
int8{2, 3, 3, -2, -3, -3, 0, 0},
133+
float8{2.f, 3.f, 3.f, -2.f, -3.f, -3.f, 0.f, 0.f});
134+
test<float, float, 8, rounding_mode::rtn>(
135+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f},
136+
float8{+2.3f, +2.5f, +2.7f, -2.3f, -2.5f, -2.7f, 0.f, 0.f});
137+
138+
return 0;
139+
}

0 commit comments

Comments
 (0)