Skip to content

Commit ee85233

Browse files
alexeyvoronov-intelbader
authored andcommitted
[SYCL][NFC] Refactor SYCL kernel invocation methods (#514)
Made the enable_if logic more readable. Moved the responsibility for indexers initialization to the Builder class. Applied naming convention. Signed-off-by: Alexey Voronov <[email protected]>
1 parent fe3bbf9 commit ee85233

File tree

3 files changed

+150
-144
lines changed

3 files changed

+150
-144
lines changed

sycl/include/CL/sycl/detail/helpers.hpp

Lines changed: 103 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
#pragma once
1010

1111
#include <CL/__spirv/spirv_types.hpp>
12+
#include <CL/__spirv/spirv_vars.hpp>
1213
#include <CL/sycl/access/access.hpp>
1314
#include <CL/sycl/detail/common.hpp>
1415
#include <CL/sycl/detail/pi.hpp>
16+
#include <CL/sycl/detail/type_traits.hpp>
1517

1618
#include <memory>
1719
#include <stdexcept>
@@ -22,13 +24,13 @@ namespace cl {
2224
namespace sycl {
2325
class context;
2426
class event;
25-
template <int dimensions, bool with_offset> class item;
26-
template <int dimensions> class group;
27-
template <int dimensions> class range;
28-
template <int dimensions> class id;
29-
template <int dimensions> class nd_item;
27+
template <int Dims, bool WithOffset> class item;
28+
template <int Dims> class group;
29+
template <int Dims> class range;
30+
template <int Dims> class id;
31+
template <int Dims> class nd_item;
32+
template <int Dims> class h_item;
3033
enum class memory_order;
31-
template <int dimensions> class h_item;
3234

3335
namespace detail {
3436
class context_impl;
@@ -43,70 +45,119 @@ void waitEvents(std::vector<cl::sycl::event> DepEvents);
4345
class Builder {
4446
public:
4547
Builder() = delete;
46-
template <int dimensions>
47-
static group<dimensions>
48-
createGroup(const cl::sycl::range<dimensions> &G,
49-
const cl::sycl::range<dimensions> &L,
50-
const cl::sycl::range<dimensions> &GroupRange,
51-
const cl::sycl::id<dimensions> &I) {
52-
return cl::sycl::group<dimensions>(G, L, GroupRange, I);
48+
49+
template <int Dims>
50+
static group<Dims>
51+
createGroup(const range<Dims> &Global, const range<Dims> &Local,
52+
const range<Dims> &Group, const id<Dims> &Index) {
53+
return group<Dims>(Global, Local, Group, Index);
54+
}
55+
56+
template <int Dims>
57+
static group<Dims> createGroup(const range<Dims> &Global,
58+
const range<Dims> &Local,
59+
const id<Dims> &Index) {
60+
return group<Dims>(Global, Local, Global / Local, Index);
61+
}
62+
63+
template <int Dims, bool WithOffset>
64+
static detail::enable_if_t<WithOffset, item<Dims, WithOffset>>
65+
createItem(const range<Dims> &Extent, const id<Dims> &Index,
66+
const id<Dims> &Offset) {
67+
return item<Dims, WithOffset>(Extent, Index, Offset);
68+
}
69+
70+
template <int Dims, bool WithOffset>
71+
static detail::enable_if_t<!WithOffset, item<Dims, WithOffset>>
72+
createItem(const range<Dims> &Extent, const id<Dims> &Index) {
73+
return item<Dims, WithOffset>(Extent, Index);
5374
}
5475

55-
template <int dimensions>
56-
static group<dimensions> createGroup(const cl::sycl::range<dimensions> &G,
57-
const cl::sycl::range<dimensions> &L,
58-
const cl::sycl::id<dimensions> &I) {
59-
return cl::sycl::group<dimensions>(G, L, G / L, I);
76+
template <int Dims>
77+
static nd_item<Dims> createNDItem(const item<Dims, true> &Global,
78+
const item<Dims, false> &Local,
79+
const group<Dims> &Group) {
80+
return nd_item<Dims>(Global, Local, Group);
6081
}
6182

62-
template <int dimensions, bool with_offset>
63-
static item<dimensions, with_offset> createItem(
64-
typename std::enable_if<(with_offset == true),
65-
const cl::sycl::range<dimensions>>::type &R,
66-
const cl::sycl::id<dimensions> &I, const cl::sycl::id<dimensions> &O) {
67-
return cl::sycl::item<dimensions, with_offset>(R, I, O);
83+
template <int Dims>
84+
static h_item<Dims> createHItem(const item<Dims, false> &Global,
85+
const item<Dims, false> &Local) {
86+
return h_item<Dims>(Global, Local);
6887
}
6988

70-
template <int dimensions, bool with_offset>
71-
static item<dimensions, with_offset> createItem(
72-
typename std::enable_if<(with_offset == false),
73-
const cl::sycl::range<dimensions>>::type &R,
74-
const cl::sycl::id<dimensions> &I) {
75-
return cl::sycl::item<dimensions, with_offset>(R, I);
89+
template <int Dims>
90+
static h_item<Dims> createHItem(const item<Dims, false> &Global,
91+
const item<Dims, false> &Local,
92+
const range<Dims> &Flex) {
93+
return h_item<Dims>(Global, Local, Flex);
7694
}
7795

78-
template <int dimensions, bool with_offset>
79-
static void updateItemIndex(cl::sycl::item<dimensions, with_offset> &Item,
80-
const id<dimensions> &NextIndex) {
96+
template <int Dims, bool WithOffset>
97+
static void updateItemIndex(cl::sycl::item<Dims, WithOffset> &Item,
98+
const id<Dims> &NextIndex) {
8199
Item.MImpl.MIndex = NextIndex;
82100
}
83101

84-
template <int dimensions>
85-
static nd_item<dimensions>
86-
createNDItem(const cl::sycl::item<dimensions, true> &GL,
87-
const cl::sycl::item<dimensions, false> &L,
88-
const cl::sycl::group<dimensions> &GR) {
89-
return cl::sycl::nd_item<dimensions>(GL, L, GR);
102+
#ifdef __SYCL_DEVICE_ONLY__
103+
104+
template <int N>
105+
using is_valid_dimensions = std::integral_constant<bool, (N > 0) && (N < 4)>;
106+
107+
template <int Dims> static const id<Dims> getId() {
108+
static_assert(is_valid_dimensions<Dims>::value, "invalid dimensions");
109+
return __spirv::initGlobalInvocationId<Dims, id<Dims>>();
110+
}
111+
112+
template <int Dims> static const group<Dims> getGroup() {
113+
static_assert(is_valid_dimensions<Dims>::value, "invalid dimensions");
114+
range<Dims> GlobalSize{__spirv::initGlobalSize<Dims, range<Dims>>()};
115+
range<Dims> LocalSize{__spirv::initWorkgroupSize<Dims, range<Dims>>()};
116+
range<Dims> GroupRange{__spirv::initNumWorkgroups<Dims, range<Dims>>()};
117+
id<Dims> GroupId{__spirv::initWorkgroupId<Dims, id<Dims>>()};
118+
return createGroup<Dims>(GlobalSize, LocalSize, GroupRange, GroupId);
119+
}
120+
121+
template <int Dims, bool WithOffset>
122+
static detail::enable_if_t<WithOffset, const item<Dims, WithOffset>>
123+
getItem() {
124+
static_assert(is_valid_dimensions<Dims>::value, "invalid dimensions");
125+
id<Dims> GlobalId{__spirv::initGlobalInvocationId<Dims, id<Dims>>()};
126+
range<Dims> GlobalSize{__spirv::initGlobalSize<Dims, range<Dims>>()};
127+
id<Dims> GlobalOffset{__spirv::initGlobalOffset<Dims, id<Dims>>()};
128+
return createItem<Dims, true>(GlobalSize, GlobalId, GlobalOffset);
90129
}
91130

92-
template <int dimensions>
93-
static h_item<dimensions>
94-
createHItem(const cl::sycl::item<dimensions, false> &GlobalItem,
95-
const cl::sycl::item<dimensions, false> &LocalItem) {
96-
return cl::sycl::h_item<dimensions>(GlobalItem, LocalItem);
131+
template <int Dims, bool WithOffset>
132+
static detail::enable_if_t<!WithOffset, const item<Dims, WithOffset>>
133+
getItem() {
134+
static_assert(is_valid_dimensions<Dims>::value, "invalid dimensions");
135+
id<Dims> GlobalId{__spirv::initGlobalInvocationId<Dims, id<Dims>>()};
136+
range<Dims> GlobalSize{__spirv::initGlobalSize<Dims, range<Dims>>()};
137+
return createItem<Dims, false>(GlobalSize, GlobalId);
97138
}
98139

99-
template <int dimensions>
100-
static h_item<dimensions>
101-
createHItem(const cl::sycl::item<dimensions, false> &GlobalItem,
102-
const cl::sycl::item<dimensions, false> &LocalItem,
103-
const cl::sycl::range<dimensions> &FlexRange) {
104-
return cl::sycl::h_item<dimensions>(GlobalItem, LocalItem, FlexRange);
140+
template <int Dims> static const nd_item<Dims> getNDItem() {
141+
static_assert(is_valid_dimensions<Dims>::value, "invalid dimensions");
142+
range<Dims> GlobalSize{__spirv::initGlobalSize<Dims, range<Dims>>()};
143+
range<Dims> LocalSize{__spirv::initWorkgroupSize<Dims, range<Dims>>()};
144+
range<Dims> GroupRange{__spirv::initNumWorkgroups<Dims, range<Dims>>()};
145+
id<Dims> GroupId{__spirv::initWorkgroupId<Dims, id<Dims>>()};
146+
id<Dims> GlobalId{__spirv::initGlobalInvocationId<Dims, id<Dims>>()};
147+
id<Dims> LocalId{__spirv::initLocalInvocationId<Dims, id<Dims>>()};
148+
id<Dims> GlobalOffset{__spirv::initGlobalOffset<Dims, id<Dims>>()};
149+
group<Dims> Group =
150+
createGroup<Dims>(GlobalSize, LocalSize, GroupRange, GroupId);
151+
item<Dims, true> GlobalItem =
152+
createItem<Dims, true>(GlobalSize, GlobalId, GlobalOffset);
153+
item<Dims, false> LocalItem = createItem<Dims, false>(LocalSize, LocalId);
154+
return createNDItem<Dims>(GlobalItem, LocalItem, Group);
105155
}
156+
#endif // __SYCL_DEVICE_ONLY__
106157
};
107158

108-
inline constexpr
109-
__spv::MemorySemanticsMask getSPIRVMemorySemanticsMask(memory_order) {
159+
inline constexpr __spv::MemorySemanticsMask
160+
getSPIRVMemorySemanticsMask(memory_order) {
110161
return __spv::MemorySemanticsMask::None;
111162
}
112163

sycl/include/CL/sycl/detail/type_traits.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ using allocator_pointer_t = typename std::allocator_traits<T>::pointer;
2525
template <bool B, class T = void>
2626
using enable_if_t = typename std::enable_if<B, T>::type;
2727

28+
template <bool B, class T, class F>
29+
using conditional_t = typename std::conditional<B, T, F>::type;
30+
2831
template <typename T>
2932
using remove_pointer_t = typename std::remove_pointer<T>::type;
3033

sycl/include/CL/sycl/handler.hpp

Lines changed: 44 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010

1111
#pragma once
1212

13-
#include <CL/__spirv/spirv_vars.hpp>
1413
#include <CL/sycl/access/access.hpp>
1514
#include <CL/sycl/context.hpp>
1615
#include <CL/sycl/detail/cg.hpp>
1716
#include <CL/sycl/detail/common.hpp>
17+
#include <CL/sycl/detail/helpers.hpp>
1818
#include <CL/sycl/detail/kernel_desc.hpp>
1919
#include <CL/sycl/detail/os_util.hpp>
2020
#include <CL/sycl/detail/scheduler/scheduler.hpp>
@@ -554,94 +554,64 @@ class handler {
554554
}
555555

556556
#ifdef __SYCL_DEVICE_ONLY__
557+
558+
template <typename KernelT, typename IndexerT>
559+
using EnableIfIndexer = detail::enable_if_t<
560+
std::is_same<detail::lambda_arg_type<KernelT>, IndexerT>::value>;
561+
562+
template <typename KernelT, int Dims>
563+
using EnableIfId = EnableIfIndexer<KernelT, id<Dims>>;
564+
565+
template <typename KernelT, int Dims>
566+
using EnableIfItemWithOffset = EnableIfIndexer<KernelT, item<Dims, true>>;
567+
568+
template <typename KernelT, int Dims>
569+
using EnableIfItemWithoutOffset = EnableIfIndexer<KernelT, item<Dims, false>>;
570+
571+
template <typename KernelT, int Dims>
572+
using EnableIfNDItem = EnableIfIndexer<KernelT, nd_item<Dims>>;
573+
557574
// NOTE: the name of this function - "kernel_single_task" - is used by the
558575
// Front End to determine kernel invocation kind.
559576
template <typename KernelName, typename KernelType>
560577
__attribute__((sycl_kernel)) void kernel_single_task(KernelType KernelFunc) {
561578
KernelFunc();
562579
}
563580

564-
template <typename KernelName, typename KernelType, int dimensions>
565-
__attribute__((sycl_kernel)) void kernel_parallel_for(
566-
typename std::enable_if<std::is_same<detail::lambda_arg_type<KernelType>,
567-
id<dimensions>>::value &&
568-
(dimensions > 0 && dimensions < 4),
569-
KernelType>::type KernelFunc) {
570-
id<dimensions> global_id{
571-
__spirv::initGlobalInvocationId<dimensions, id<dimensions>>()};
581+
// NOTE: the name of these functions - "kernel_parallel_for" - are used by the
582+
// Front End to determine kernel invocation kind.
583+
template <typename KernelName, typename KernelType, int Dims>
584+
__attribute__((sycl_kernel)) EnableIfId<KernelType, Dims>
585+
kernel_parallel_for(KernelType KernelFunc) {
586+
KernelFunc(detail::Builder::getId<Dims>());
587+
}
572588

573-
KernelFunc(global_id);
589+
template <typename KernelName, typename KernelType, int Dims>
590+
__attribute__((sycl_kernel)) EnableIfItemWithoutOffset<KernelType, Dims>
591+
kernel_parallel_for(KernelType KernelFunc) {
592+
KernelFunc(detail::Builder::getItem<Dims, false>());
574593
}
575594

576-
// NOTE: the name of this function - "kernel_parallel_for" - is used by the
577-
// Front End to determine kernel invocation kind.
578-
template <typename KernelName, typename KernelType, int dimensions>
579-
__attribute__((sycl_kernel)) void kernel_parallel_for(
580-
typename std::enable_if<std::is_same<detail::lambda_arg_type<KernelType>,
581-
item<dimensions, false>>::value &&
582-
(dimensions > 0 && dimensions < 4),
583-
KernelType>::type KernelFunc) {
584-
id<dimensions> global_id{
585-
__spirv::initGlobalInvocationId<dimensions, id<dimensions>>()};
586-
range<dimensions> global_size{
587-
__spirv::initGlobalSize<dimensions, range<dimensions>>()};
588-
589-
item<dimensions, false> Item =
590-
detail::Builder::createItem<dimensions, false>(global_size, global_id);
591-
KernelFunc(Item);
595+
template <typename KernelName, typename KernelType, int Dims>
596+
__attribute__((sycl_kernel)) EnableIfItemWithOffset<KernelType, Dims>
597+
kernel_parallel_for(KernelType KernelFunc) {
598+
KernelFunc(detail::Builder::getItem<Dims, true>());
592599
}
593600

594-
template <typename KernelName, typename KernelType, int dimensions>
595-
__attribute__((sycl_kernel)) void kernel_parallel_for(
596-
typename std::enable_if<std::is_same<detail::lambda_arg_type<KernelType>,
597-
item<dimensions, true>>::value &&
598-
(dimensions > 0 && dimensions < 4),
599-
KernelType>::type KernelFunc) {
600-
id<dimensions> global_id{
601-
__spirv::initGlobalInvocationId<dimensions, id<dimensions>>()};
602-
range<dimensions> global_size{
603-
__spirv::initGlobalSize<dimensions, range<dimensions>>()};
604-
id<dimensions> global_offset{
605-
__spirv::initGlobalOffset<dimensions, id<dimensions>>()};
606-
607-
item<dimensions, true> Item = detail::Builder::createItem<dimensions, true>(
608-
global_size, global_id, global_offset);
609-
KernelFunc(Item);
601+
template <typename KernelName, typename KernelType, int Dims>
602+
__attribute__((sycl_kernel)) EnableIfNDItem<KernelType, Dims>
603+
kernel_parallel_for(KernelType KernelFunc) {
604+
KernelFunc(detail::Builder::getNDItem<Dims>());
610605
}
611606

612-
template <typename KernelName, typename KernelType, int dimensions>
613-
__attribute__((sycl_kernel)) void kernel_parallel_for(
614-
typename std::enable_if<std::is_same<detail::lambda_arg_type<KernelType>,
615-
nd_item<dimensions>>::value &&
616-
(dimensions > 0 && dimensions < 4),
617-
KernelType>::type KernelFunc) {
618-
range<dimensions> global_size{
619-
__spirv::initGlobalSize<dimensions, range<dimensions>>()};
620-
range<dimensions> local_size{
621-
__spirv::initWorkgroupSize<dimensions, range<dimensions>>()};
622-
range<dimensions> group_range{
623-
__spirv::initNumWorkgroups<dimensions, range<dimensions>>()};
624-
id<dimensions> group_id{
625-
__spirv::initWorkgroupId<dimensions, id<dimensions>>()};
626-
id<dimensions> global_id{
627-
__spirv::initGlobalInvocationId<dimensions, id<dimensions>>()};
628-
id<dimensions> local_id{
629-
__spirv::initLocalInvocationId<dimensions, id<dimensions>>()};
630-
id<dimensions> global_offset{
631-
__spirv::initGlobalOffset<dimensions, id<dimensions>>()};
632-
633-
group<dimensions> Group = detail::Builder::createGroup<dimensions>(
634-
global_size, local_size, group_range, group_id);
635-
item<dimensions, true> globalItem =
636-
detail::Builder::createItem<dimensions, true>(global_size, global_id,
637-
global_offset);
638-
item<dimensions, false> localItem =
639-
detail::Builder::createItem<dimensions, false>(local_size, local_id);
640-
nd_item<dimensions> Nd_item =
641-
detail::Builder::createNDItem<dimensions>(globalItem, localItem, Group);
642-
643-
KernelFunc(Nd_item);
607+
// NOTE: the name of this function - "kernel_parallel_for_work_group" - is
608+
// used by the Front End to determine kernel invocation kind.
609+
template <typename KernelName, typename KernelType, int Dims>
610+
__attribute__((sycl_kernel)) void
611+
kernel_parallel_for_work_group(KernelType KernelFunc) {
612+
KernelFunc(detail::Builder::getGroup<Dims>());
644613
}
614+
645615
#endif
646616

647617
// The method stores lambda to the template-free object and initializes
@@ -740,24 +710,6 @@ class handler {
740710
#endif // __SYCL_DEVICE_ONLY__
741711
}
742712

743-
#ifdef __SYCL_DEVICE_ONLY__
744-
// NOTE: the name of this function - "kernel_parallel_for_work_group" - is
745-
// used by the Front End to determine kernel invocation kind.
746-
template <typename KernelName, typename KernelType, int Dims>
747-
__attribute__((sycl_kernel)) void
748-
kernel_parallel_for_work_group(KernelType KernelFunc) {
749-
750-
range<Dims> GlobalSize{__spirv::initGlobalSize<Dims, range<Dims>>()};
751-
range<Dims> LocalSize{__spirv::initWorkgroupSize<Dims, range<Dims>>()};
752-
range<Dims> GroupRange{__spirv::initNumWorkgroups<Dims, range<Dims>>()};
753-
id<Dims> GroupId{__spirv::initWorkgroupId<Dims, id<Dims>>()};
754-
755-
group<Dims> G = detail::Builder::createGroup<Dims>(GlobalSize, LocalSize,
756-
GroupRange, GroupId);
757-
KernelFunc(G);
758-
}
759-
#endif // __SYCL_DEVICE_ONLY__
760-
761713
template <typename KernelName = csd::auto_name, typename KernelType, int Dims>
762714
void parallel_for_work_group(range<Dims> NumWorkGroups,
763715
range<Dims> WorkGroupSize,

0 commit comments

Comments
 (0)