Skip to content

Commit b7e5f95

Browse files
[SYCL] Don't include <complex> from <sycl/sycl.hpp>
This addresses a part of a bigger issue of polluting the global namespace with math functions by implicitly including <cmath> from sycl headers. Internally, <complex> includes <cmath> so we shouldn't be including it as well, which this patch implements. I'm doing it by providing our own version of <complex> that does `#include_next <complex>` and also provides needed functions/template specializations to support `std::complex`-related functionality whenever user program includes <complex> on its own. System include path for this is only provided in SYCL mode so non-SYCL compilation flow is unaffected.
1 parent b290e9f commit b7e5f95

File tree

13 files changed

+129
-63
lines changed

13 files changed

+129
-63
lines changed

clang/lib/Driver/ToolChains/SYCL.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,9 +1128,16 @@ void SYCLToolChain::AddSYCLIncludeArgs(const clang::driver::Driver &Driver,
11281128
llvm::sys::path::append(P, "include");
11291129
SmallString<128> SYCLP(P);
11301130
llvm::sys::path::append(SYCLP, "sycl");
1131+
// This is used to provide our wrappers around STL headers that provide
1132+
// additional functions/template specializations when the user includes those
1133+
// STL headers in their programs (e.g., <complex>).
1134+
SmallString<128> STL_WRAPPERS_PATH(SYCLP);
1135+
llvm::sys::path::append(STL_WRAPPERS_PATH, "stl_wrappers");
11311136
CC1Args.push_back("-internal-isystem");
11321137
CC1Args.push_back(DriverArgs.MakeArgString(SYCLP));
11331138
CC1Args.push_back("-internal-isystem");
1139+
CC1Args.push_back(DriverArgs.MakeArgString(STL_WRAPPERS_PATH));
1140+
CC1Args.push_back("-internal-isystem");
11341141
CC1Args.push_back(DriverArgs.MakeArgString(P));
11351142
}
11361143

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,19 +1212,6 @@ __CLC_BF16_SCAL_VEC(uint32_t)
12121212
#undef __CLC_BF16_SCAL_VEC
12131213
#undef __CLC_BF16
12141214

1215-
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
1216-
__SYCL_EXPORT __spv::complex_half
1217-
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
1218-
__spv::complex_half) noexcept;
1219-
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
1220-
__SYCL_EXPORT __spv::complex_float
1221-
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
1222-
__spv::complex_float) noexcept;
1223-
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
1224-
__SYCL_EXPORT __spv::complex_double
1225-
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
1226-
__spv::complex_double) noexcept;
1227-
12281215
extern __DPCPP_SYCL_EXTERNAL int32_t __spirv_BuiltInGlobalHWThreadIDINTEL();
12291216
extern __DPCPP_SYCL_EXTERNAL int32_t __spirv_BuiltInSubDeviceIDINTEL();
12301217

sycl/include/CL/__spirv/spirv_types.hpp

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include <sycl/detail/defines.hpp> // for SYCL_EXT_ONEAPI_MATRIX_VERSION
1212
#include <sycl/half_type.hpp> // for half
1313

14-
#include <complex> // for complex
1514
#include <cstddef> // for size_t
1615
#include <cstdint> // for uint32_t
1716

@@ -130,27 +129,6 @@ enum class MatrixLayout : uint32_t {
130129

131130
enum class MatrixUse : uint32_t { MatrixA = 0, MatrixB = 1, Accumulator = 2 };
132131

133-
struct complex_float {
134-
complex_float() = default;
135-
complex_float(std::complex<float> x) : real(x.real()), imag(x.imag()) {}
136-
operator std::complex<float>() { return {real, imag}; }
137-
float real, imag;
138-
};
139-
140-
struct complex_double {
141-
complex_double() = default;
142-
complex_double(std::complex<double> x) : real(x.real()), imag(x.imag()) {}
143-
operator std::complex<double>() { return {real, imag}; }
144-
double real, imag;
145-
};
146-
147-
struct complex_half {
148-
complex_half() = default;
149-
complex_half(std::complex<sycl::half> x) : real(x.real()), imag(x.imag()) {}
150-
operator std::complex<sycl::half>() { return {real, imag}; }
151-
sycl::half real, imag;
152-
};
153-
154132
#if (SYCL_EXT_ONEAPI_MATRIX_VERSION > 1)
155133
template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
156134
Scope::Flag S = Scope::Flag::Subgroup,

sycl/include/sycl/builtins.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
#include <sycl/builtins_scalar_gen.hpp>
1616
#include <sycl/builtins_vector_gen.hpp>
1717

18+
// We don't use the same exception specifier as <cmath> so we get warnings if
19+
// our code is processed before STL's <cmath>.
20+
// TODO: We should remove this dependency alltogether in a subsequent patch.
21+
#include <cmath>
22+
1823
#ifdef __SYCL_DEVICE_ONLY__
1924
extern "C" {
2025
extern __DPCPP_SYCL_EXTERNAL int abs(int x);

sycl/include/sycl/detail/generic_type_traits.hpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#pragma once
1010

11-
#include <CL/__spirv/spirv_types.hpp> // for complex_double, comple...
1211
#include <sycl/access/access.hpp> // for decorated, address_space
1312
#include <sycl/aliases.hpp> // for half, cl_char, cl_double
1413
#include <sycl/detail/generic_type_lists.hpp> // for nonconst_address_space...
@@ -18,7 +17,6 @@
1817
#include <sycl/half_type.hpp> // for BIsRepresentationT
1918
#include <sycl/multi_ptr.hpp> // for multi_ptr, address_spa...
2019

21-
#include <complex> // for complex
2220
#include <cstddef> // for byte
2321
#include <cstdint> // for uint8_t
2422
#include <limits> // for numeric_limits
@@ -485,13 +483,17 @@ using select_cl_scalar_float_t =
485483
select_apply_cl_scalar_t<T, std::false_type, sycl::opencl::cl_half,
486484
sycl::opencl::cl_float, sycl::opencl::cl_double>;
487485

486+
// Use SFINAE so that std::complex specialization could be implemented in
487+
// include/sycl/stl_wrappers/complex that would only be available if STL's
488+
// <complex> is included by users.
489+
template <typename T, typename = void>
490+
struct select_cl_scalar_complex_or_T {
491+
using type = T;
492+
};
493+
488494
template <typename T>
489-
using select_cl_scalar_complex_or_T_t = std::conditional_t<
490-
std::is_same_v<T, std::complex<float>>, __spv::complex_float,
491-
std::conditional_t<std::is_same_v<T, std::complex<double>>,
492-
__spv::complex_double,
493-
std::conditional_t<std::is_same_v<T, std::complex<half>>,
494-
__spv::complex_half, T>>>;
495+
using select_cl_scalar_complex_or_T_t =
496+
typename select_cl_scalar_complex_or_T<T>::type;
495497

496498
template <typename T>
497499
using select_cl_scalar_integral_t =

sycl/include/sycl/ext/oneapi/functional.hpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,13 @@ struct GroupOpTag<T, std::enable_if_t<detail::is_sgenfloat<T>::value>> {
5353
using type = GroupOpFP;
5454
};
5555

56-
template <typename T>
57-
struct GroupOpTag<
58-
T, std::enable_if_t<std::is_same<T, std::complex<half>>::value ||
59-
std::is_same<T, std::complex<float>>::value ||
60-
std::is_same<T, std::complex<double>>::value>> {
61-
using type = GroupOpC;
62-
};
63-
6456
template <typename T>
6557
struct GroupOpTag<T, std::enable_if_t<detail::is_genbool<T>::value>> {
6658
using type = GroupOpBool;
6759
};
6860

61+
// GroupOpC (std::complex) is handled in sycl/stl_wrappers/complex.
62+
6963
#define __SYCL_CALC_OVERLOAD(GroupTag, SPIRVOperation, BinaryOperation) \
7064
template <__spv::GroupOperation O, typename Group, typename T> \
7165
static T calc(Group g, GroupTag, T x, BinaryOperation) { \

sycl/include/sycl/group_algorithm.hpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
#endif
3232
#endif
3333

34-
#include <complex> // for complex
3534
#include <stddef.h> // for size_t
3635
#include <type_traits> // for enable_if_t, decay_t, integra...
3736

@@ -127,13 +126,11 @@ using is_multiplies = std::integral_constant<
127126
std::is_same_v<BinaryOperation, sycl::multiplies<void>>>;
128127

129128
// ---- is_complex
130-
// NOTE: std::complex<long double> not yet supported by group algorithms.
131-
template <typename T>
132-
struct is_complex
133-
: std::integral_constant<bool,
134-
std::is_same_v<T, std::complex<half>> ||
135-
std::is_same_v<T, std::complex<float>> ||
136-
std::is_same_v<T, std::complex<double>>> {};
129+
// Use SFINAE so that the "true" branch could be implemented in
130+
// include/sycl/stl_wrappers/complex that would only be available if STL's
131+
// <complex> is included by users.
132+
template <typename T, typename = void>
133+
struct is_complex : public std::false_type {};
137134

138135
// ---- is_arithmetic_or_complex
139136
template <typename T>

sycl/include/sycl/known_identity.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include <sycl/marray.hpp> // for marray
1616
#include <sycl/types.hpp> // for vec
1717

18-
#include <complex> // for complex
1918
#include <cstddef> // for byte, size_t
2019
#include <functional> // for logical_and, logical_or
2120
#include <limits> // for numeric_limits
@@ -75,9 +74,11 @@ using IsLogicalOR =
7574
std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
7675
std::is_same_v<BinaryOperation, sycl::logical_or<void>>>;
7776

78-
template <typename T>
79-
using isComplex = std::bool_constant<std::is_same_v<T, std::complex<float>> ||
80-
std::is_same_v<T, std::complex<double>>>;
77+
// Use SFINAE so that the "true" branch could be implemented in
78+
// include/sycl/stl_wrappers/complex that would only be available if STL's
79+
// <complex> is included by users.
80+
template <typename T, typename = void>
81+
struct isComplex : public std::false_type {};
8182

8283
// Identity = 0
8384
template <typename T, class BinaryOperation>
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#pragma once
2+
3+
#include_next <complex>
4+
5+
#include <type_traits>
6+
7+
#include <CL/__spirv/spirv_ops.hpp> // for __SYCL_CONVERGENT__
8+
#include <sycl/half_type.hpp> // for half
9+
10+
// We provide std::complex specializations here for the following:
11+
// select_cl_scalar_complex_or_T:
12+
#include <sycl/detail/generic_type_traits.hpp>
13+
// sycl::detail::GroupOpTag:
14+
#include <sycl/ext/oneapi/functional.hpp>
15+
// sycl::detail::is_complex:
16+
#include <sycl/group_algorithm.hpp>
17+
// sycl::detail::isComplex
18+
#include <sycl/known_identity.hpp>
19+
20+
namespace __spv {
21+
struct complex_float {
22+
complex_float() = default;
23+
complex_float(std::complex<float> x) : real(x.real()), imag(x.imag()) {}
24+
operator std::complex<float>() { return {real, imag}; }
25+
float real, imag;
26+
};
27+
28+
struct complex_double {
29+
complex_double() = default;
30+
complex_double(std::complex<double> x) : real(x.real()), imag(x.imag()) {}
31+
operator std::complex<double>() { return {real, imag}; }
32+
double real, imag;
33+
};
34+
35+
struct complex_half {
36+
complex_half() = default;
37+
complex_half(std::complex<sycl::half> x) : real(x.real()), imag(x.imag()) {}
38+
operator std::complex<sycl::half>() { return {real, imag}; }
39+
sycl::half real, imag;
40+
};
41+
} // namespace __spv
42+
43+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
44+
__SYCL_EXPORT __spv::complex_half
45+
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
46+
__spv::complex_half) noexcept;
47+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
48+
__SYCL_EXPORT __spv::complex_float
49+
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
50+
__spv::complex_float) noexcept;
51+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
52+
__SYCL_EXPORT __spv::complex_double
53+
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
54+
__spv::complex_double) noexcept;
55+
56+
namespace sycl {
57+
inline namespace _V1 {
58+
namespace detail {
59+
template <typename T>
60+
struct isComplex<T, std::enable_if_t<std::is_same_v<T, std::complex<float>> ||
61+
std::is_same_v<T, std::complex<double>>>>
62+
: public std::true_type {};
63+
64+
// NOTE: std::complex<long double> not yet supported by group algorithms.
65+
template <typename T>
66+
struct is_complex<T, std::enable_if_t<std::is_same_v<T, std::complex<half>> ||
67+
std::is_same_v<T, std::complex<float>> ||
68+
std::is_same_v<T, std::complex<double>>>>
69+
: public std::true_type {};
70+
71+
#ifdef __SYCL_DEVICE_ONLY__
72+
template <typename T>
73+
struct GroupOpTag<
74+
T, std::enable_if_t<std::is_same<T, std::complex<half>>::value ||
75+
std::is_same<T, std::complex<float>>::value ||
76+
std::is_same<T, std::complex<double>>::value>> {
77+
using type = GroupOpC;
78+
};
79+
#endif
80+
81+
template <typename T>
82+
struct select_cl_scalar_complex_or_T<T,
83+
std::enable_if_t<is_complex<T>::value>> {
84+
using type = std::conditional_t<
85+
std::is_same_v<T, std::complex<float>>, __spv::complex_float,
86+
std::conditional_t<std::is_same_v<T, std::complex<double>>,
87+
__spv::complex_double, __spv::complex_half>>;
88+
};
89+
} // namespace detail
90+
} // namespace _V1
91+
} // namespace sycl

sycl/test-e2e/GroupAlgorithm/different_types.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: %{build} -fsycl-device-code-split=per_kernel -I . -o %t.out
22
// RUN: %{run} %t.out
33

4+
#include <complex>
45
#include <cstdint>
56
#include <limits>
67
#include <numeric>

sycl/test-e2e/GroupAlgorithm/exclusive_scan_sycl2020.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "support.h"
66
#include <algorithm>
77
#include <cassert>
8+
#include <complex>
89
#include <iostream>
910
#include <limits>
1011
#include <numeric>

sycl/test-e2e/GroupAlgorithm/inclusive_scan_sycl2020.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "support.h"
66
#include <algorithm>
77
#include <cassert>
8+
#include <complex>
89
#include <iostream>
910
#include <limits>
1011
#include <numeric>

sycl/test-e2e/UserDefinedReductions/user_defined_reductions.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
//
44
// UNSUPPORTED: cuda || hip
55

6+
#include <complex>
67
#include <numeric>
78

89
#include <sycl/ext/oneapi/experimental/user_defined_reductions.hpp>

0 commit comments

Comments
 (0)