Skip to content

Commit 644013e

Browse files
rolandschulzromanovvlad
authored andcommitted
[SYCL] Add range, id, and nd_range CTAD support (#772)
Also replace one enable_if with static_assert for better diagnostic. Signed-off-by: Roland Schulz <[email protected]>
1 parent c514d25 commit 644013e

File tree

5 files changed

+68
-4
lines changed

5 files changed

+68
-4
lines changed

sycl/include/CL/sycl/id.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,5 +160,14 @@ size_t getOffsetForId(range<dimensions> Range, id<dimensions> Id,
160160
return offset;
161161
}
162162
} // namespace detail
163+
164+
// C++ feature test macros are supported by all supported compilers
165+
// with the exception of MSVC 1914. It doesn't support deduction guides.
166+
#ifdef __cpp_deduction_guides
167+
id(size_t)->id<1>;
168+
id(size_t, size_t)->id<2>;
169+
id(size_t, size_t, size_t)->id<3>;
170+
#endif
171+
163172
} // namespace sycl
164173
} // namespace cl

sycl/include/CL/sycl/nd_range.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ template <int dimensions = 1> class nd_range {
2020
range<dimensions> globalSize;
2121
range<dimensions> localSize;
2222
id<dimensions> offset;
23+
static_assert(dimensions >= 1 && dimensions <= 3,
24+
"nd_range can only be 1, 2, or 3 dimensional.");
2325

2426
public:
25-
template <int N = dimensions>
26-
nd_range(
27-
typename std::enable_if<((N > 0) && (N < 4)), range<dimensions>>::type globalSize,
28-
range<dimensions> localSize, id<dimensions> offset = id<dimensions>())
27+
nd_range(range<dimensions> globalSize, range<dimensions> localSize,
28+
id<dimensions> offset = id<dimensions>())
2929
: globalSize(globalSize), localSize(localSize), offset(offset) {}
3030

3131
range<dimensions> get_global_range() const { return globalSize; }

sycl/include/CL/sycl/range.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,5 +134,11 @@ template <int dimensions = 1> class range : public detail::array<dimensions> {
134134
#undef __SYCL_GEN_OPT
135135
};
136136

137+
#ifdef __cpp_deduction_guides
138+
range(size_t)->range<1>;
139+
range(size_t, size_t)->range<2>;
140+
range(size_t, size_t, size_t)->range<3>;
141+
#endif
142+
137143
} // namespace sycl
138144
} // namespace cl

sycl/test/basic_tests/id_ctad.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: %clangxx -std=c++17 -fsyntax-only -Xclang -verify %s
2+
// expected-no-diagnostics
3+
//==--------------- id_ctad.cpp - SYCL id CTAD test ----------------------==//
4+
//
5+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6+
// See https://llvm.org/LICENSE.txt for license information.
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
//===----------------------------------------------------------------------===//
10+
#include <CL/sycl.hpp>
11+
12+
using namespace std;
13+
int main() {
14+
cl::sycl::id one_dim_id(64);
15+
cl::sycl::id two_dim_id(64, 1);
16+
cl::sycl::id three_dim_id(64, 1, 2);
17+
static_assert(std::is_same_v<decltype(one_dim_id), cl::sycl::id<1>>);
18+
static_assert(std::is_same_v<decltype(two_dim_id), cl::sycl::id<2>>);
19+
static_assert(std::is_same_v<decltype(three_dim_id), cl::sycl::id<3>>);
20+
}

sycl/test/basic_tests/range_ctad.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: %clangxx -std=c++17 -fsyntax-only -Xclang -verify %s
2+
// expected-no-diagnostics
3+
//==--------------- range_ctad.cpp - SYCL range CTAD test ----------------------==//
4+
//
5+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6+
// See https://llvm.org/LICENSE.txt for license information.
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
//===----------------------------------------------------------------------===//
10+
#include <CL/sycl.hpp>
11+
12+
using namespace std;
13+
int main() {
14+
cl::sycl::range one_dim_range(64);
15+
cl::sycl::range two_dim_range(64, 1);
16+
cl::sycl::range three_dim_range(64, 1, 2);
17+
static_assert(std::is_same_v<decltype(one_dim_range), cl::sycl::range<1>>);
18+
static_assert(std::is_same_v<decltype(two_dim_range), cl::sycl::range<2>>);
19+
static_assert(std::is_same_v<decltype(three_dim_range), cl::sycl::range<3>>);
20+
cl::sycl::nd_range one_dim_ndrange(one_dim_range, one_dim_range);
21+
cl::sycl::nd_range two_dim_ndrange(two_dim_range, two_dim_range);
22+
cl::sycl::nd_range three_dim_ndrange(three_dim_range, three_dim_range);
23+
static_assert(
24+
std::is_same_v<decltype(one_dim_ndrange), cl::sycl::nd_range<1>>);
25+
static_assert(
26+
std::is_same_v<decltype(two_dim_ndrange), cl::sycl::nd_range<2>>);
27+
static_assert(
28+
std::is_same_v<decltype(three_dim_ndrange), cl::sycl::nd_range<3>>);
29+
}

0 commit comments

Comments
 (0)