Skip to content

Commit 6ed0ab8

Browse files
[SYCL] Don't include <complex> from <sycl/sycl.hpp> (#11196)
This addresses a part of a bigger issue of polluting the global namespace with math functions by implicitly including `<cmath>` from sycl headers. Internally, `<complex>` includes `<cmath>` so we shouldn't be including it as well, which this patch implements. I'm doing it by providing our own version of `<complex>` that does `#include_next <complex>` and also provides the needed functions/template specializations to support `std::complex`-related functionality whenever the user program includes `<complex>` on its own. The system include path for this is only provided in SYCL mode, so the non-SYCL compilation flow is unaffected. Unfortunately, MSVC doesn't have `#include_next` so we use a hack to emulate it - via `#include <../include/complex>`.
1 parent c7d3c00 commit 6ed0ab8

File tree

17 files changed

+243
-73
lines changed

17 files changed

+243
-73
lines changed

clang/lib/Driver/ToolChains/Clang.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4853,9 +4853,16 @@ void Clang::ConstructHostCompilerJob(Compilation &C, const JobAction &JA,
48534853
llvm::sys::path::append(BaseDir, "..", "include");
48544854
SmallString<128> SYCLDir(BaseDir);
48554855
llvm::sys::path::append(SYCLDir, "sycl");
4856+
// This is used to provide our wrappers around STL headers that provide
4857+
// additional functions/template specializations when the user includes those
4858+
// STL headers in their programs (e.g., <complex>).
4859+
SmallString<128> STLWrappersDir(SYCLDir);
4860+
llvm::sys::path::append(STLWrappersDir, "stl_wrappers");
48564861
HostCompileArgs.push_back("-I");
48574862
HostCompileArgs.push_back(TCArgs.MakeArgString(SYCLDir));
48584863
HostCompileArgs.push_back("-I");
4864+
HostCompileArgs.push_back(TCArgs.MakeArgString(STLWrappersDir));
4865+
HostCompileArgs.push_back("-I");
48594866
HostCompileArgs.push_back(TCArgs.MakeArgString(BaseDir));
48604867

48614868
if (!OutputAdded) {

clang/lib/Driver/ToolChains/SYCL.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,16 +1122,24 @@ SYCLToolChain::GetCXXStdlibType(const ArgList &Args) const {
11221122
void SYCLToolChain::AddSYCLIncludeArgs(const clang::driver::Driver &Driver,
11231123
const ArgList &DriverArgs,
11241124
ArgStringList &CC1Args) {
1125-
// Add ../include/sycl and ../include (in that order)
1126-
SmallString<128> P(Driver.getInstalledDir());
1127-
llvm::sys::path::append(P, "..");
1128-
llvm::sys::path::append(P, "include");
1129-
SmallString<128> SYCLP(P);
1130-
llvm::sys::path::append(SYCLP, "sycl");
1125+
// Add ../include/sycl, ../include/sycl/stl_wrappers and ../include (in that
1126+
// order).
1127+
SmallString<128> IncludePath(Driver.getInstalledDir());
1128+
llvm::sys::path::append(IncludePath, "..");
1129+
llvm::sys::path::append(IncludePath, "include");
1130+
SmallString<128> SYCLPath(IncludePath);
1131+
llvm::sys::path::append(SYCLPath, "sycl");
1132+
// This is used to provide our wrappers around STL headers that provide
1133+
// additional functions/template specializations when the user includes those
1134+
// STL headers in their programs (e.g., <complex>).
1135+
SmallString<128> STLWrappersPath(SYCLPath);
1136+
llvm::sys::path::append(STLWrappersPath, "stl_wrappers");
11311137
CC1Args.push_back("-internal-isystem");
1132-
CC1Args.push_back(DriverArgs.MakeArgString(SYCLP));
1138+
CC1Args.push_back(DriverArgs.MakeArgString(SYCLPath));
11331139
CC1Args.push_back("-internal-isystem");
1134-
CC1Args.push_back(DriverArgs.MakeArgString(P));
1140+
CC1Args.push_back(DriverArgs.MakeArgString(STLWrappersPath));
1141+
CC1Args.push_back("-internal-isystem");
1142+
CC1Args.push_back(DriverArgs.MakeArgString(IncludePath));
11351143
}
11361144

11371145
void SYCLToolChain::AddClangSystemIncludeArgs(const ArgList &DriverArgs,

clang/test/Driver/sycl-offload.c

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,8 +622,18 @@
622622
// Verify header search dirs are added with -fsycl
623623
// RUN: %clang -### -fsycl %s 2>&1 | FileCheck %s -check-prefixes=CHECK-HEADER-DIR
624624
// RUN: %clang_cl -### -fsycl %s 2>&1 | FileCheck %s -check-prefixes=CHECK-HEADER-DIR
625-
// CHECK-HEADER-DIR: clang{{.*}} "-fsycl-is-device"{{.*}} "-internal-isystem" "{{.*}}bin{{[/\\]+}}..{{[/\\]+}}include{{[/\\]+}}sycl" "-internal-isystem" "{{.*}}bin{{[/\\]+}}..{{[/\\]+}}include"
626-
// CHECK-HEADER-DIR: clang{{.*}} "-fsycl-is-host"{{.*}} "-internal-isystem" "{{.*}}bin{{[/\\]+}}..{{[/\\]+}}include{{[/\\]+}}sycl" "-internal-isystem" "{{.*}}bin{{[/\\]+}}..{{[/\\]+}}include"{{.*}}
625+
// CHECK-HEADER-DIR: clang{{.*}} "-fsycl-is-device"
626+
// CHECK-HEADER-DIR-SAME: "-internal-isystem" "[[ROOT:[^"]*]]bin{{[/\\]+}}..{{[/\\]+}}include{{[/\\]+}}sycl"
627+
// CHECK-HEADER-DIR-NOT: -internal-isystem
628+
// CHECK-HEADER-DIR-SAME: "-internal-isystem" "[[ROOT]]bin{{[/\\]+}}..{{[/\\]+}}include{{[/\\]+}}sycl{{[/\\]+}}stl_wrappers"
629+
// CHECK-HEADER-DIR-NOT: -internal-isystem
630+
// CHECK-HEADER-DIR-SAME: "-internal-isystem" "[[ROOT]]bin{{[/\\]+}}..{{[/\\]+}}include"
631+
// CHECK-HEADER-DIR: clang{{.*}} "-fsycl-is-host"
632+
// CHECK-HEADER-DIR-SAME: "-internal-isystem" "[[ROOT]]bin{{[/\\]+}}..{{[/\\]+}}include{{[/\\]+}}sycl"
633+
// CHECK-HEADER-DIR-NOT: -internal-isystem
634+
// CHECK-HEADER-DIR-SAME: "-internal-isystem" "[[ROOT]]bin{{[/\\]+}}..{{[/\\]+}}include{{[/\\]+}}sycl{{[/\\]+}}stl_wrappers"
635+
// CHECK-HEADER-DIR-NOT: -internal-isystem
636+
// CHECK-HEADER-DIR-SAME: "-internal-isystem" "[[ROOT]]bin{{[/\\]+}}..{{[/\\]+}}include"
627637

628638
/// Check for option incompatibility with -fsycl
629639
// RUN: not %clang -### -fsycl -ffreestanding %s 2>&1 \

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,19 +1212,6 @@ __CLC_BF16_SCAL_VEC(uint32_t)
12121212
#undef __CLC_BF16_SCAL_VEC
12131213
#undef __CLC_BF16
12141214

1215-
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
1216-
__SYCL_EXPORT __spv::complex_half
1217-
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
1218-
__spv::complex_half) noexcept;
1219-
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
1220-
__SYCL_EXPORT __spv::complex_float
1221-
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
1222-
__spv::complex_float) noexcept;
1223-
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
1224-
__SYCL_EXPORT __spv::complex_double
1225-
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
1226-
__spv::complex_double) noexcept;
1227-
12281215
extern __DPCPP_SYCL_EXTERNAL int32_t __spirv_BuiltInGlobalHWThreadIDINTEL();
12291216
extern __DPCPP_SYCL_EXTERNAL int32_t __spirv_BuiltInSubDeviceIDINTEL();
12301217

sycl/include/CL/__spirv/spirv_types.hpp

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include <sycl/detail/defines.hpp> // for SYCL_EXT_ONEAPI_MATRIX_VERSION
1212
#include <sycl/half_type.hpp> // for half
1313

14-
#include <complex> // for complex
1514
#include <cstddef> // for size_t
1615
#include <cstdint> // for uint32_t
1716

@@ -130,27 +129,6 @@ enum class MatrixLayout : uint32_t {
130129

131130
enum class MatrixUse : uint32_t { MatrixA = 0, MatrixB = 1, Accumulator = 2 };
132131

133-
struct complex_float {
134-
complex_float() = default;
135-
complex_float(std::complex<float> x) : real(x.real()), imag(x.imag()) {}
136-
operator std::complex<float>() { return {real, imag}; }
137-
float real, imag;
138-
};
139-
140-
struct complex_double {
141-
complex_double() = default;
142-
complex_double(std::complex<double> x) : real(x.real()), imag(x.imag()) {}
143-
operator std::complex<double>() { return {real, imag}; }
144-
double real, imag;
145-
};
146-
147-
struct complex_half {
148-
complex_half() = default;
149-
complex_half(std::complex<sycl::half> x) : real(x.real()), imag(x.imag()) {}
150-
operator std::complex<sycl::half>() { return {real, imag}; }
151-
sycl::half real, imag;
152-
};
153-
154132
#if (SYCL_EXT_ONEAPI_MATRIX_VERSION > 1)
155133
template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
156134
Scope::Flag S = Scope::Flag::Subgroup,

sycl/include/sycl/builtins.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
#include <sycl/builtins_scalar_gen.hpp>
1616
#include <sycl/builtins_vector_gen.hpp>
1717

18+
// We don't use the same exception specifier as <cmath> so we get warnings if
19+
// our code is processed before STL's <cmath>.
20+
// TODO: We should remove this dependency altogether in a subsequent patch.
21+
#include <cmath>
22+
1823
#ifdef __SYCL_DEVICE_ONLY__
1924
extern "C" {
2025
extern __DPCPP_SYCL_EXTERNAL int abs(int x);

sycl/include/sycl/detail/generic_type_traits.hpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#pragma once
1010

11-
#include <CL/__spirv/spirv_types.hpp> // for complex_double, comple...
1211
#include <sycl/access/access.hpp> // for decorated, address_space
1312
#include <sycl/aliases.hpp> // for half, cl_char, cl_double
1413
#include <sycl/detail/generic_type_lists.hpp> // for nonconst_address_space...
@@ -18,7 +17,6 @@
1817
#include <sycl/half_type.hpp> // for BIsRepresentationT
1918
#include <sycl/multi_ptr.hpp> // for multi_ptr, address_spa...
2019

21-
#include <complex> // for complex
2220
#include <cstddef> // for byte
2321
#include <cstdint> // for uint8_t
2422
#include <limits> // for numeric_limits
@@ -485,13 +483,16 @@ using select_cl_scalar_float_t =
485483
select_apply_cl_scalar_t<T, std::false_type, sycl::opencl::cl_half,
486484
sycl::opencl::cl_float, sycl::opencl::cl_double>;
487485

486+
// Use SFINAE so that std::complex specialization could be implemented in
487+
// include/sycl/stl_wrappers/complex that would only be available if STL's
488+
// <complex> is included by users.
489+
template <typename T, typename = void> struct select_cl_scalar_complex_or_T {
490+
using type = T;
491+
};
492+
488493
template <typename T>
489-
using select_cl_scalar_complex_or_T_t = std::conditional_t<
490-
std::is_same_v<T, std::complex<float>>, __spv::complex_float,
491-
std::conditional_t<std::is_same_v<T, std::complex<double>>,
492-
__spv::complex_double,
493-
std::conditional_t<std::is_same_v<T, std::complex<half>>,
494-
__spv::complex_half, T>>>;
494+
using select_cl_scalar_complex_or_T_t =
495+
typename select_cl_scalar_complex_or_T<T>::type;
495496

496497
template <typename T>
497498
using select_cl_scalar_integral_t =

sycl/include/sycl/ext/oneapi/functional.hpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,13 @@ struct GroupOpTag<T, std::enable_if_t<detail::is_sgenfloat<T>::value>> {
5353
using type = GroupOpFP;
5454
};
5555

56-
template <typename T>
57-
struct GroupOpTag<
58-
T, std::enable_if_t<std::is_same<T, std::complex<half>>::value ||
59-
std::is_same<T, std::complex<float>>::value ||
60-
std::is_same<T, std::complex<double>>::value>> {
61-
using type = GroupOpC;
62-
};
63-
6456
template <typename T>
6557
struct GroupOpTag<T, std::enable_if_t<detail::is_genbool<T>::value>> {
6658
using type = GroupOpBool;
6759
};
6860

61+
// GroupOpC (std::complex) is handled in sycl/stl_wrappers/complex.
62+
6963
#define __SYCL_CALC_OVERLOAD(GroupTag, SPIRVOperation, BinaryOperation) \
7064
template <__spv::GroupOperation O, typename Group, typename T> \
7165
static T calc(Group g, GroupTag, T x, BinaryOperation) { \

sycl/include/sycl/ext/oneapi/sub_group_mask.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <sycl/exception.hpp> // for errc, exception
1313
#include <sycl/id.hpp> // for id
1414
#include <sycl/marray.hpp> // for marray
15+
#include <sycl/types.hpp> // for vec
1516

1617
#include <assert.h> // for assert
1718
#include <climits> // for CHAR_BIT

sycl/include/sycl/group_algorithm.hpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
#endif
3232
#endif
3333

34-
#include <complex> // for complex
3534
#include <stddef.h> // for size_t
3635
#include <type_traits> // for enable_if_t, decay_t, integra...
3736

@@ -127,13 +126,11 @@ using is_multiplies = std::integral_constant<
127126
std::is_same_v<BinaryOperation, sycl::multiplies<void>>>;
128127

129128
// ---- is_complex
130-
// NOTE: std::complex<long double> not yet supported by group algorithms.
131-
template <typename T>
132-
struct is_complex
133-
: std::integral_constant<bool,
134-
std::is_same_v<T, std::complex<half>> ||
135-
std::is_same_v<T, std::complex<float>> ||
136-
std::is_same_v<T, std::complex<double>>> {};
129+
// Use SFINAE so that the "true" branch could be implemented in
130+
// include/sycl/stl_wrappers/complex that would only be available if STL's
131+
// <complex> is included by users.
132+
template <typename T, typename = void>
133+
struct is_complex : public std::false_type {};
137134

138135
// ---- is_arithmetic_or_complex
139136
template <typename T>

sycl/include/sycl/known_identity.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include <sycl/marray.hpp> // for marray
1616
#include <sycl/types.hpp> // for vec
1717

18-
#include <complex> // for complex
1918
#include <cstddef> // for byte, size_t
2019
#include <functional> // for logical_and, logical_or
2120
#include <limits> // for numeric_limits
@@ -75,9 +74,11 @@ using IsLogicalOR =
7574
std::is_same_v<BinaryOperation, sycl::logical_or<T>> ||
7675
std::is_same_v<BinaryOperation, sycl::logical_or<void>>>;
7776

78-
template <typename T>
79-
using isComplex = std::bool_constant<std::is_same_v<T, std::complex<float>> ||
80-
std::is_same_v<T, std::complex<double>>>;
77+
// Use SFINAE so that the "true" branch could be implemented in
78+
// include/sycl/stl_wrappers/complex that would only be available if STL's
79+
// <complex> is included by users.
80+
template <typename T, typename = void>
81+
struct isComplex : public std::false_type {};
8182

8283
// Identity = 0
8384
template <typename T, class BinaryOperation>
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
//==---------------- <complex> wrapper around STL --------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// STL's <complex> includes <cmath> which, in turn, pollutes global namespace.
10+
// As such, we cannot include <complex> from SYCL headers unconditionally and
11+
// have to provide support for std::complex only when the customer included
12+
// <complex> explicitly. Do that by providing our own <complex> that is
13+
// implemented as a wrapper around the STL header using "#include_next"
14+
// functionality.
15+
16+
#pragma once
17+
18+
// Include real STL <complex> header - the next one from the include search
19+
// directories.
20+
#if defined(__has_include_next)
21+
// GCC and clang support "#include_next" and go through this path.
22+
#include_next <complex>
23+
#else
24+
// MSVC doesn't support "#include_next", so we have to be creative.
25+
// Our header is located in "stl_wrappers/complex" so it won't be picked by the
26+
// following include. MSVC's installation, on the other hand, has the layout
27+
// where the following would result in the <complex> we want. This is obviously
28+
// hacky, but the best we can do...
29+
#include <../include/complex>
30+
#endif
31+
32+
// Now that we have std::complex available, implement SYCL functionality related
33+
// to it.
34+
35+
#include <type_traits>
36+
37+
#include <CL/__spirv/spirv_ops.hpp> // for __SYCL_CONVERGENT__
38+
#include <sycl/half_type.hpp> // for half
39+
40+
// We provide std::complex specializations here for the following:
41+
// select_cl_scalar_complex_or_T:
42+
#include <sycl/detail/generic_type_traits.hpp>
43+
// sycl::detail::GroupOpTag:
44+
#include <sycl/ext/oneapi/functional.hpp>
45+
// sycl::detail::is_complex:
46+
#include <sycl/group_algorithm.hpp>
47+
// sycl::detail::isComplex
48+
#include <sycl/known_identity.hpp>
49+
50+
namespace __spv {
51+
struct complex_float {
52+
complex_float() = default;
53+
complex_float(std::complex<float> x) : real(x.real()), imag(x.imag()) {}
54+
operator std::complex<float>() { return {real, imag}; }
55+
float real, imag;
56+
};
57+
58+
struct complex_double {
59+
complex_double() = default;
60+
complex_double(std::complex<double> x) : real(x.real()), imag(x.imag()) {}
61+
operator std::complex<double>() { return {real, imag}; }
62+
double real, imag;
63+
};
64+
65+
struct complex_half {
66+
complex_half() = default;
67+
complex_half(std::complex<sycl::half> x) : real(x.real()), imag(x.imag()) {}
68+
operator std::complex<sycl::half>() { return {real, imag}; }
69+
sycl::half real, imag;
70+
};
71+
} // namespace __spv
72+
73+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
74+
__SYCL_EXPORT __spv::complex_half
75+
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
76+
__spv::complex_half) noexcept;
77+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
78+
__SYCL_EXPORT __spv::complex_float
79+
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
80+
__spv::complex_float) noexcept;
81+
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
82+
__SYCL_EXPORT __spv::complex_double
83+
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
84+
__spv::complex_double) noexcept;
85+
86+
namespace sycl {
87+
inline namespace _V1 {
88+
namespace detail {
89+
template <typename T>
90+
struct isComplex<T, std::enable_if_t<std::is_same_v<T, std::complex<float>> ||
91+
std::is_same_v<T, std::complex<double>>>>
92+
: public std::true_type {};
93+
94+
// NOTE: std::complex<long double> not yet supported by group algorithms.
95+
template <typename T>
96+
struct is_complex<T, std::enable_if_t<std::is_same_v<T, std::complex<half>> ||
97+
std::is_same_v<T, std::complex<float>> ||
98+
std::is_same_v<T, std::complex<double>>>>
99+
: public std::true_type {};
100+
101+
#ifdef __SYCL_DEVICE_ONLY__
102+
template <typename T>
103+
struct GroupOpTag<
104+
T, std::enable_if_t<std::is_same<T, std::complex<half>>::value ||
105+
std::is_same<T, std::complex<float>>::value ||
106+
std::is_same<T, std::complex<double>>::value>> {
107+
using type = GroupOpC;
108+
};
109+
#endif
110+
111+
template <typename T>
112+
struct select_cl_scalar_complex_or_T<T,
113+
std::enable_if_t<is_complex<T>::value>> {
114+
using type = std::conditional_t<
115+
std::is_same_v<T, std::complex<float>>, __spv::complex_float,
116+
std::conditional_t<std::is_same_v<T, std::complex<double>>,
117+
__spv::complex_double, __spv::complex_half>>;
118+
};
119+
} // namespace detail
120+
} // namespace _V1
121+
} // namespace sycl

0 commit comments

Comments
 (0)