Skip to content

Commit baed6a5

Browse files
authored
[SYCL][CUDA] Add basic sub-group functionality (#2587)
Implements: - Sub-group local id - Sub-group id - Number of sub-groups - Sub-group size - Max sub-group size The implementations are functionally correct, but may benefit from additional optimization. Signed-off-by: John Pennycook <[email protected]> --- The implementation is different to the one proposed in https://intel.github.io/llvm-docs/cuda/opencl-subgroup-vs-cuda-crosslane-op.html, because I don't think `sreg.warpid` and `sreg.nwarpid` have the correct semantics for sub-groups. The mapping from work-items to sub-groups is invariant during a kernel's execution, which isn't true of the warp ID in PTX. As far as I can tell, the number of warp IDs represents the maximum number of warps that can execute in a CTA rather than the number of warps in a CTA. NVIDIA's [PTX documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#special-registers-warpid) says that `tid` should be used to compute a "virtual warp ID" if one is required, which is what I've implemented. I convinced myself that we have to compute the sub-group IDs and sizes from the linear size of the work-group, and couldn't find a simpler way to express this. Ideally, we wouldn't have to re-compute each of these values on every call. It would be sufficient to compute them once at the start of the kernel and then re-use them, but I don't have enough knowledge of Clang/LLVM/libclc to implement that.
1 parent 15cac43 commit baed6a5

File tree

15 files changed

+185
-20
lines changed

15 files changed

+185
-20
lines changed

libclc/generic/include/spirv/spirv.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,18 @@
3737
#include <macros.h>
3838

3939
/* 6.11.1 Work-Item Functions */
40-
#include <spirv/workitem/get_global_size.h>
4140
#include <spirv/workitem/get_global_id.h>
42-
#include <spirv/workitem/get_local_size.h>
41+
#include <spirv/workitem/get_global_offset.h>
42+
#include <spirv/workitem/get_global_size.h>
43+
#include <spirv/workitem/get_group_id.h>
4344
#include <spirv/workitem/get_local_id.h>
45+
#include <spirv/workitem/get_local_size.h>
46+
#include <spirv/workitem/get_max_sub_group_size.h>
4447
#include <spirv/workitem/get_num_groups.h>
45-
#include <spirv/workitem/get_group_id.h>
46-
#include <spirv/workitem/get_global_offset.h>
48+
#include <spirv/workitem/get_num_sub_groups.h>
49+
#include <spirv/workitem/get_sub_group_id.h>
50+
#include <spirv/workitem/get_sub_group_local_id.h>
51+
#include <spirv/workitem/get_sub_group_size.h>
4752
#include <spirv/workitem/get_work_dim.h>
4853

4954
/* 6.11.2.1 Floating-point macros */
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Query for the maximum sub-group size; takes no arguments and yields a uint.
// Only the declaration lives here; a definition is expected elsewhere.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupMaxSize();
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Query for the number of sub-groups; takes no arguments and yields a uint.
// Only the declaration lives here; a definition is expected elsewhere.
_CLC_DEF _CLC_OVERLOAD uint __spirv_NumSubgroups();
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Query for the sub-group id; takes no arguments and yields a uint.
// Only the declaration lives here; a definition is expected elsewhere.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupId();
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Query for the sub-group local invocation id (the id within a sub-group);
// takes no arguments and yields a uint.
// Only the declaration lives here; a definition is expected elsewhere.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupLocalInvocationId();
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Query for the sub-group size; takes no arguments and yields a uint.
// Only the declaration lives here; a definition is expected elsewhere.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupSize();

libclc/ptx-nvidiacl/libspirv/SOURCES

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ workitem/get_global_size.cl
7575
workitem/get_group_id.cl
7676
workitem/get_local_id.cl
7777
workitem/get_local_size.cl
78+
workitem/get_max_sub_group_size.cl
7879
workitem/get_num_groups.cl
80+
workitem/get_num_sub_groups.cl
81+
workitem/get_sub_group_id.cl
82+
workitem/get_sub_group_local_id.cl
83+
workitem/get_sub_group_size.cl
7984
images/image_helpers.ll
8085
images/image.cl
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <spirv/spirv.h>
10+
11+
// Reports the maximum sub-group size as a fixed value of 32.
//
// FIXME: warpsize is defined by NVVM IR but doesn't compile if used here, so
// the value is hard-coded instead of being read with
// __nvvm_read_ptx_sreg_warpsize().
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupMaxSize() {
  const uint max_sub_group_size = 32;
  return max_sub_group_size;
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <spirv/spirv.h>
10+
11+
// Counts the sub-groups in the work-group: the linear work-group size divided
// by the maximum sub-group size, rounded up.
//
// sreg.nwarpid is not used because it returns the number of warp identifiers,
// not the number of warps; see
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
_CLC_DEF _CLC_OVERLOAD uint __spirv_NumSubgroups() {
  const uint max_size = __spirv_SubgroupMaxSize();
  const size_t dim_x = __spirv_WorkgroupSize_x();
  const size_t dim_y = __spirv_WorkgroupSize_y();
  const size_t dim_z = __spirv_WorkgroupSize_z();
  const uint work_items = dim_z * dim_y * dim_x;
  // Ceiling division, so a partially filled trailing sub-group is counted.
  const uint rounded_up = work_items + max_size - 1;
  return rounded_up / max_size;
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <spirv/spirv.h>
10+
11+
// Returns the index of the calling work-item's sub-group: the work-item's
// linear local id (x varies fastest, then y, then z) divided by the maximum
// sub-group size.
//
// sreg.warpid is not used because it is volatile and doesn't represent a
// virtual warp index; see
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupId() {
  size_t id_x = __spirv_LocalInvocationId_x();
  size_t id_y = __spirv_LocalInvocationId_y();
  size_t id_z = __spirv_LocalInvocationId_z();
  // Only the x and y extents are needed to linearize a 3-D local id; the z
  // extent never appears in the formula, so it is not queried.
  size_t size_x = __spirv_WorkgroupSize_x();
  size_t size_y = __spirv_WorkgroupSize_y();
  uint sg_size = __spirv_SubgroupMaxSize();
  size_t linear_id = id_z * size_y * size_x + id_y * size_x + id_x;
  return linear_id / sg_size;
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <spirv/spirv.h>
10+
11+
// Returns the sub-group local invocation id, taken directly from the value of
// __nvvm_read_ptx_sreg_laneid().
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupLocalInvocationId() {
  const uint lane = __nvvm_read_ptx_sreg_laneid();
  return lane;
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <spirv/spirv.h>
10+
11+
// Returns the size of the calling work-item's sub-group.
//
// Every sub-group except the last one has the maximum size. The last one
// holds whatever remains of the linear work-group size after the preceding
// sub-groups are accounted for.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupSize() {
  const uint max_size = __spirv_SubgroupMaxSize();
  // Index of the last sub-group, which is also the count of sub-groups that
  // precede it.
  const uint last_sub_group = __spirv_NumSubgroups() - 1;
  if (__spirv_SubgroupId() == last_sub_group) {
    const size_t dim_x = __spirv_WorkgroupSize_x();
    const size_t dim_y = __spirv_WorkgroupSize_y();
    const size_t dim_z = __spirv_WorkgroupSize_z();
    const uint work_items = dim_z * dim_y * dim_x;
    const uint preceding_items = max_size * last_sub_group;
    return work_items - preceding_items;
  }
  return max_size;
}

sycl/include/CL/__spirv/spirv_vars.hpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ SYCL_EXTERNAL size_t __spirv_LocalInvocationId_x();
4545
SYCL_EXTERNAL size_t __spirv_LocalInvocationId_y();
4646
SYCL_EXTERNAL size_t __spirv_LocalInvocationId_z();
4747

48+
SYCL_EXTERNAL uint32_t __spirv_SubgroupSize();
49+
SYCL_EXTERNAL uint32_t __spirv_SubgroupMaxSize();
50+
SYCL_EXTERNAL uint32_t __spirv_NumSubgroups();
51+
SYCL_EXTERNAL uint32_t __spirv_SubgroupId();
52+
SYCL_EXTERNAL uint32_t __spirv_SubgroupLocalInvocationId();
53+
4854
#else // __SYCL_NVPTX__
4955

5056
typedef size_t size_t_vec __attribute__((ext_vector_type(3)));
@@ -56,6 +62,12 @@ __SPIRV_VAR_QUALIFIERS size_t_vec __spirv_BuiltInLocalInvocationId;
5662
__SPIRV_VAR_QUALIFIERS size_t_vec __spirv_BuiltInWorkgroupId;
5763
__SPIRV_VAR_QUALIFIERS size_t_vec __spirv_BuiltInGlobalOffset;
5864

65+
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupSize;
66+
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupMaxSize;
67+
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInNumSubgroups;
68+
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupId;
69+
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupLocalInvocationId;
70+
5971
SYCL_EXTERNAL inline size_t __spirv_GlobalInvocationId_x() {
6072
return __spirv_BuiltInGlobalInvocationId.x;
6173
}
@@ -126,14 +138,23 @@ SYCL_EXTERNAL inline size_t __spirv_LocalInvocationId_z() {
126138
return __spirv_BuiltInLocalInvocationId.z;
127139
}
128140

129-
#endif // __SYCL_NVPTX__
141+
/// Function-style accessor: returns the current value of the
/// __spirv_BuiltInSubgroupSize variable.
SYCL_EXTERNAL inline uint32_t __spirv_SubgroupSize() {
  const uint32_t Size = __spirv_BuiltInSubgroupSize;
  return Size;
}
144+
/// Function-style accessor: returns the current value of the
/// __spirv_BuiltInSubgroupMaxSize variable.
SYCL_EXTERNAL inline uint32_t __spirv_SubgroupMaxSize() {
  const uint32_t MaxSize = __spirv_BuiltInSubgroupMaxSize;
  return MaxSize;
}
147+
/// Function-style accessor: returns the current value of the
/// __spirv_BuiltInNumSubgroups variable.
SYCL_EXTERNAL inline uint32_t __spirv_NumSubgroups() {
  const uint32_t Count = __spirv_BuiltInNumSubgroups;
  return Count;
}
150+
/// Function-style accessor: returns the current value of the
/// __spirv_BuiltInSubgroupId variable.
SYCL_EXTERNAL inline uint32_t __spirv_SubgroupId() {
  const uint32_t Id = __spirv_BuiltInSubgroupId;
  return Id;
}
153+
/// Function-style accessor: returns the current value of the
/// __spirv_BuiltInSubgroupLocalInvocationId variable.
SYCL_EXTERNAL inline uint32_t __spirv_SubgroupLocalInvocationId() {
  const uint32_t LocalId = __spirv_BuiltInSubgroupLocalInvocationId;
  return LocalId;
}
130156

131-
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupSize;
132-
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupMaxSize;
133-
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInNumSubgroups;
134-
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInNumEnqueuedSubgroups;
135-
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupId;
136-
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupLocalInvocationId;
157+
#endif // __SYCL_NVPTX__
137158

138159
#undef __SPIRV_VAR_QUALIFIERS
139160

sycl/include/CL/sycl/ONEAPI/sub_group.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ struct sub_group {
109109

110110
id_type get_local_id() const {
111111
#ifdef __SYCL_DEVICE_ONLY__
112-
return __spirv_BuiltInSubgroupLocalInvocationId;
112+
return __spirv_SubgroupLocalInvocationId();
113113
#else
114114
throw runtime_error("Sub-groups are not supported on host device.",
115115
PI_INVALID_DEVICE);
@@ -127,7 +127,7 @@ struct sub_group {
127127

128128
range_type get_local_range() const {
129129
#ifdef __SYCL_DEVICE_ONLY__
130-
return __spirv_BuiltInSubgroupSize;
130+
return __spirv_SubgroupSize();
131131
#else
132132
throw runtime_error("Sub-groups are not supported on host device.",
133133
PI_INVALID_DEVICE);
@@ -136,7 +136,7 @@ struct sub_group {
136136

137137
range_type get_max_local_range() const {
138138
#ifdef __SYCL_DEVICE_ONLY__
139-
return __spirv_BuiltInSubgroupMaxSize;
139+
return __spirv_SubgroupMaxSize();
140140
#else
141141
throw runtime_error("Sub-groups are not supported on host device.",
142142
PI_INVALID_DEVICE);
@@ -145,7 +145,7 @@ struct sub_group {
145145

146146
id_type get_group_id() const {
147147
#ifdef __SYCL_DEVICE_ONLY__
148-
return __spirv_BuiltInSubgroupId;
148+
return __spirv_SubgroupId();
149149
#else
150150
throw runtime_error("Sub-groups are not supported on host device.",
151151
PI_INVALID_DEVICE);
@@ -163,7 +163,7 @@ struct sub_group {
163163

164164
range_type get_group_range() const {
165165
#ifdef __SYCL_DEVICE_ONLY__
166-
return __spirv_BuiltInNumSubgroups;
166+
return __spirv_NumSubgroups();
167167
#else
168168
throw runtime_error("Sub-groups are not supported on host device.",
169169
PI_INVALID_DEVICE);

sycl/test/sub_group/common.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
// UNSUPPORTED: cuda
2-
// CUDA compilation and runtime do not yet support sub-groups.
3-
//
41
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
52
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
63
// RUN: %CPU_RUN_PLACEHOLDER %t.out
@@ -70,7 +67,7 @@ void check(queue &Queue, unsigned int G, unsigned int L) {
7067
}
7168
int main() {
7269
queue Queue;
73-
if (!core_sg_supported(Queue.get_device())) {
70+
if (Queue.get_device().is_host()) {
7471
std::cout << "Skipping test\n";
7572
return 0;
7673
}

0 commit comments

Comments
 (0)