Skip to content

Commit 80d23d0

Browse files
authored
[SYCL][COMPAT] util header split in math and util headers (#12957)
This PR prepares SYCLcompat to receive multiple PRs containing updates to the helper headers. It will substitute the previous opened PR as it grew in scope too much, as in case of issues it would be difficult to track and resolve effectively the problem. Edit: Previous PR was #11267
1 parent f3abf58 commit 80d23d0

File tree

10 files changed

+268
-222
lines changed

10 files changed

+268
-222
lines changed

sycl/doc/syclcompat/README.md

Lines changed: 52 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,16 +1068,6 @@ Functionality is provided to represent a pair of integers as a `double`.
10681068
in the high & low 32-bits respectively. `cast_double_to_int` casts the high or
10691069
low 32-bits back into an integer.
10701070

1071-
`syclcompat::fast_length` provides a wrapper to SYCL's
1072-
`fast_length(sycl::vec<float,N>)` that accepts arguments for a C++ array and a
1073-
length.
1074-
1075-
`vectorized_max` and `vectorized_min` are binary operations returning the
1076-
max/min of two arguments, where each argument is treated as a `sycl::vec` type.
1077-
`vectorized_isgreater` performs elementwise `isgreater`, treating each argument
1078-
as a vector of elements, and returning `0` for vector components for which
1079-
`isgreater` is false, and `-1` when true.
1080-
10811071
`reverse_bits` reverses the bits of a 32-bit unsigned integer, `ffs` returns the
10821072
position of the first least significant set bit in an integer.
10831073
`byte_level_permute` returns a byte-permutation of two input unsigned integers,
@@ -1093,61 +1083,34 @@ functionality to `sycl::select_from_group`, `sycl::shift_group_left`,
10931083
However, they provide an optional argument to represent the `logical_group` size
10941084
(default 32).
10951085

1096-
The functions `cmul`,`cdiv`,`cabs`, and `conj` define complex math operations
1097-
which accept `sycl::vec<T,2>` arguments representing complex values.
1098-
10991086
```c++
11001087
namespace syclcompat {
11011088

11021089
inline int cast_double_to_int(double d, bool use_high32 = true);
11031090

11041091
inline double cast_ints_to_double(int high32, int low32);
11051092

1106-
inline float fast_length(const float *a, int len);
1107-
1108-
template <typename S, typename T> inline T vectorized_max(T a, T b);
1109-
1110-
template <typename S, typename T> inline T vectorized_min(T a, T b);
1111-
1112-
template <typename S, typename T> inline T vectorized_isgreater(T a, T b);
1113-
1114-
template <>
1115-
inline unsigned vectorized_isgreater<sycl::ushort2, unsigned>(unsigned a,
1116-
unsigned b);
1117-
1118-
template <typename T> inline T reverse_bits(T a);
1119-
11201093
inline unsigned int byte_level_permute(unsigned int a, unsigned int b,
11211094
unsigned int s);
11221095

1123-
template <typename T> inline int ffs(T a);
1096+
template <typename ValueT> inline int ffs(ValueT a);
11241097

1125-
template <typename T>
1126-
T select_from_sub_group(sycl::sub_group g, T x, int remote_local_id,
1098+
template <typename ValueT>
1099+
ValueT select_from_sub_group(sycl::sub_group g, ValueT x, int remote_local_id,
11271100
int logical_sub_group_size = 32);
11281101

1129-
template <typename T>
1130-
T shift_sub_group_left(sycl::sub_group g, T x, unsigned int delta,
1102+
template <typename ValueT>
1103+
ValueT shift_sub_group_left(sycl::sub_group g, ValueT x, unsigned int delta,
11311104
int logical_sub_group_size = 32);
11321105

1133-
template <typename T>
1134-
T shift_sub_group_right(sycl::sub_group g, T x, unsigned int delta,
1106+
template <typename ValueT>
1107+
ValueT shift_sub_group_right(sycl::sub_group g, ValueT x, unsigned int delta,
11351108
int logical_sub_group_size = 32);
11361109

1137-
template <typename T>
1138-
T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask,
1110+
template <typename ValueT>
1111+
ValueT permute_sub_group_by_xor(sycl::sub_group g, ValueT x, unsigned int mask,
11391112
int logical_sub_group_size = 32);
11401113

1141-
template <typename T>
1142-
sycl::vec<T, 2> cmul(sycl::vec<T, 2> x, sycl::vec<T, 2> y);
1143-
1144-
template <typename T>
1145-
sycl::vec<T, 2> cdiv(sycl::vec<T, 2> x, sycl::vec<T, 2> y);
1146-
1147-
template <typename T> T cabs(sycl::vec<T, 2> x);
1148-
1149-
template <typename T> sycl::vec<T, 2> conj(sycl::vec<T, 2> x);
1150-
11511114
} // namespace syclcompat
11521115
```
11531116
@@ -1211,7 +1174,7 @@ int get_sycl_language_version();
12111174
} // namespace syclcompat
12121175
```
12131176
1214-
#### Kernel Helper Functions
1177+
### Kernel Helper Functions
12151178
12161179
Kernel helper functions provide a structure `kernel_function_info` to keep SYCL
12171180
kernel information, and provide a utility function `get_kernel_function_info()`
@@ -1232,6 +1195,47 @@ static kernel_function_info get_kernel_function_info(const void *function);
12321195
} // namespace syclcompat
12331196
```
12341197

1198+
### Math Functions
1199+
1200+
`syclcompat::fast_length` provides a wrapper to SYCL's
1201+
`fast_length(sycl::vec<float,N>)` that accepts arguments for a C++ array and a
1202+
length.
1203+
1204+
`vectorized_max` and `vectorized_min` are binary operations returning the
1205+
max/min of two arguments, where each argument is treated as a `sycl::vec` type.
1206+
`vectorized_isgreater` performs elementwise `isgreater`, treating each argument
1207+
as a vector of elements, and returning `0` for vector components for which
1208+
`isgreater` is false, and `-1` when true.
1209+
1210+
The functions `cmul`,`cdiv`,`cabs`, and `conj` define complex math operations
1211+
which accept `sycl::vec<T,2>` arguments representing complex values.
1212+
1213+
```cpp
1214+
inline float fast_length(const float *a, int len);
1215+
1216+
template <typename S, typename T> inline T vectorized_max(T a, T b);
1217+
1218+
template <typename S, typename T> inline T vectorized_min(T a, T b);
1219+
1220+
template <typename S, typename T> inline T vectorized_isgreater(T a, T b);
1221+
1222+
template <>
1223+
inline unsigned vectorized_isgreater<sycl::ushort2, unsigned>(unsigned a,
1224+
unsigned b);
1225+
1226+
template <typename T>
1227+
sycl::vec<T, 2> cmul(sycl::vec<T, 2> x, sycl::vec<T, 2> y);
1228+
1229+
template <typename T>
1230+
sycl::vec<T, 2> cdiv(sycl::vec<T, 2> x, sycl::vec<T, 2> y);
1231+
1232+
template <typename T> T cabs(sycl::vec<T, 2> x);
1233+
1234+
template <typename T> sycl::vec<T, 2> conj(sycl::vec<T, 2> x);
1235+
1236+
template <typename ValueT> inline ValueT reverse_bits(ValueT a);
1237+
```
1238+
12351239
## Sample Code
12361240
12371241
Below is a simple linear algebra sample, which computes `y = mx + b` implemented
@@ -1331,7 +1335,7 @@ int main(int argc, char **argv) {
13311335
13321336
// Check output
13331337
for (size_t i = 0; i < n_points; i++) {
1334-
assert(h_Y[i] - h_expected[i] < 1e6);
1338+
assert(h_Y[i] - h_expected[i] < 1e-6);
13351339
}
13361340
13371341
// Clean up memory

sycl/include/syclcompat/math.hpp

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
/***************************************************************************
2+
*
3+
* Copyright (C) Codeplay Software Ltd.
4+
*
5+
* Part of the LLVM Project, under the Apache License v2.0 with LLVM
6+
* Exceptions. See https://llvm.org/LICENSE.txt for license information.
7+
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*
15+
* SYCL compatibility extension
16+
*
17+
* math.hpp
18+
*
19+
* Description:
20+
* math utilities for the SYCL compatibility extension.
21+
**************************************************************************/
22+
23+
// The original source was under the license below:
24+
//==---- math.hpp ---------------------------------*- C++ -*----------------==//
25+
//
26+
// Copyright (C) Intel Corporation
27+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
28+
// See https://llvm.org/LICENSE.txt for license information.
29+
//
30+
//===----------------------------------------------------------------------===//
31+
32+
#pragma once
33+
34+
#include <sycl/sycl.hpp>
35+
36+
#ifndef SYCL_EXT_ONEAPI_COMPLEX
37+
#define SYCL_EXT_ONEAPI_COMPLEX
38+
#endif
39+
40+
#include <sycl/ext/oneapi/experimental/complex/complex.hpp>
41+
42+
namespace syclcompat {
43+
namespace detail {
44+
45+
namespace complex_namespace = sycl::ext::oneapi::experimental;
46+
47+
template <typename ValueT>
48+
using complex_type = detail::complex_namespace::complex<ValueT>;
49+
50+
} // namespace detail
51+
52+
/// Compute fast_length for variable-length array
53+
/// \param [in] a The array
54+
/// \param [in] len Length of the array
55+
/// \returns The computed fast_length
56+
inline float fast_length(const float *a, int len) {
57+
switch (len) {
58+
case 1:
59+
return sycl::fast_length(a[0]);
60+
case 2:
61+
return sycl::fast_length(sycl::float2(a[0], a[1]));
62+
case 3:
63+
return sycl::fast_length(sycl::float3(a[0], a[1], a[2]));
64+
case 4:
65+
return sycl::fast_length(sycl::float4(a[0], a[1], a[2], a[3]));
66+
case 0:
67+
return 0;
68+
default:
69+
float f = 0;
70+
for (int i = 0; i < len; ++i)
71+
f += a[i] * a[i];
72+
return sycl::sqrt(f);
73+
}
74+
}
75+
76+
/// Compute vectorized max for two values, with each value treated as a vector
77+
/// type \p S
78+
/// \param [in] S The type of the vector
79+
/// \param [in] T The type of the original values
80+
/// \param [in] a The first value
81+
/// \param [in] b The second value
82+
/// \returns The vectorized max of the two values
83+
template <typename S, typename T> inline T vectorized_max(T a, T b) {
84+
sycl::vec<T, 1> v0{a}, v1{b};
85+
auto v2 = v0.template as<S>();
86+
auto v3 = v1.template as<S>();
87+
v2 = sycl::max(v2, v3);
88+
v0 = v2.template as<sycl::vec<T, 1>>();
89+
return v0;
90+
}
91+
92+
/// Compute vectorized min for two values, with each value treated as a vector
93+
/// type \p S
94+
/// \param [in] S The type of the vector
95+
/// \param [in] T The type of the original values
96+
/// \param [in] a The first value
97+
/// \param [in] b The second value
98+
/// \returns The vectorized min of the two values
99+
template <typename S, typename T> inline T vectorized_min(T a, T b) {
100+
sycl::vec<T, 1> v0{a}, v1{b};
101+
auto v2 = v0.template as<S>();
102+
auto v3 = v1.template as<S>();
103+
v2 = sycl::min(v2, v3);
104+
v0 = v2.template as<sycl::vec<T, 1>>();
105+
return v0;
106+
}
107+
108+
/// Compute vectorized isgreater for two values, with each value treated as a
109+
/// vector type \p S
110+
/// \param [in] S The type of the vector
111+
/// \param [in] T The type of the original values
112+
/// \param [in] a The first value
113+
/// \param [in] b The second value
114+
/// \returns The vectorized greater than of the two values
115+
template <typename S, typename T> inline T vectorized_isgreater(T a, T b) {
116+
sycl::vec<T, 1> v0{a}, v1{b};
117+
auto v2 = v0.template as<S>();
118+
auto v3 = v1.template as<S>();
119+
auto v4 = sycl::isgreater(v2, v3);
120+
v0 = v4.template as<sycl::vec<T, 1>>();
121+
return v0;
122+
}
123+
124+
/// Compute vectorized isgreater for two unsigned int values, with each value
125+
/// treated as a vector of two unsigned short
126+
/// \param [in] a The first value
127+
/// \param [in] b The second value
128+
/// \returns The vectorized greater than of the two values
129+
template <>
130+
inline unsigned vectorized_isgreater<sycl::ushort2, unsigned>(unsigned a,
131+
unsigned b) {
132+
sycl::vec<unsigned, 1> v0{a}, v1{b};
133+
auto v2 = v0.template as<sycl::ushort2>();
134+
auto v3 = v1.template as<sycl::ushort2>();
135+
sycl::ushort2 v4;
136+
v4[0] = v2[0] > v3[0];
137+
v4[1] = v2[1] > v3[1];
138+
v0 = v4.template as<sycl::vec<unsigned, 1>>();
139+
return v0;
140+
}
141+
142+
/// Computes the multiplication of two complex numbers.
143+
/// \tparam T Complex element type
144+
/// \param [in] x The first input complex number
145+
/// \param [in] y The second input complex number
146+
/// \returns The result
147+
template <typename T>
148+
sycl::vec<T, 2> cmul(sycl::vec<T, 2> x, sycl::vec<T, 2> y) {
149+
sycl::ext::oneapi::experimental::complex<T> t1(x[0], x[1]), t2(y[0], y[1]);
150+
t1 = t1 * t2;
151+
return sycl::vec<T, 2>(t1.real(), t1.imag());
152+
}
153+
154+
/// Computes the division of two complex numbers.
155+
/// \tparam T Complex element type
156+
/// \param [in] x The first input complex number
157+
/// \param [in] y The second input complex number
158+
/// \returns The result
159+
template <typename T>
160+
sycl::vec<T, 2> cdiv(sycl::vec<T, 2> x, sycl::vec<T, 2> y) {
161+
sycl::ext::oneapi::experimental::complex<T> t1(x[0], x[1]), t2(y[0], y[1]);
162+
t1 = t1 / t2;
163+
return sycl::vec<T, 2>(t1.real(), t1.imag());
164+
}
165+
166+
/// Computes the magnitude of a complex number.
167+
/// \tparam T Complex element type
168+
/// \param [in] x The input complex number
169+
/// \returns The result
170+
template <typename T> T cabs(sycl::vec<T, 2> x) {
171+
sycl::ext::oneapi::experimental::complex<T> t(x[0], x[1]);
172+
return abs(t);
173+
}
174+
175+
/// Computes the complex conjugate of a complex number.
176+
/// \tparam T Complex element type
177+
/// \param [in] x The input complex number
178+
/// \returns The result
179+
template <typename T> sycl::vec<T, 2> conj(sycl::vec<T, 2> x) {
180+
sycl::ext::oneapi::experimental::complex<T> t(x[0], x[1]);
181+
t = conj(t);
182+
return sycl::vec<T, 2>(t.real(), t.imag());
183+
}
184+
185+
} // namespace syclcompat

sycl/include/syclcompat/syclcompat.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,6 @@
2929
#include <syclcompat/id_query.hpp>
3030
#include <syclcompat/kernel.hpp>
3131
#include <syclcompat/launch.hpp>
32+
#include <syclcompat/math.hpp>
3233
#include <syclcompat/memory.hpp>
3334
#include <syclcompat/util.hpp>

0 commit comments

Comments
 (0)