Skip to content

[SYCL][CUDA] Add basic sub-group functionality #2587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions libclc/generic/include/spirv/spirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,18 @@
#include <macros.h>

/* 6.11.1 Work-Item Functions */
#include <spirv/workitem/get_global_size.h>
#include <spirv/workitem/get_global_id.h>
#include <spirv/workitem/get_local_size.h>
#include <spirv/workitem/get_global_offset.h>
#include <spirv/workitem/get_global_size.h>
#include <spirv/workitem/get_group_id.h>
#include <spirv/workitem/get_local_id.h>
#include <spirv/workitem/get_local_size.h>
#include <spirv/workitem/get_max_sub_group_size.h>
#include <spirv/workitem/get_num_groups.h>
#include <spirv/workitem/get_group_id.h>
#include <spirv/workitem/get_global_offset.h>
#include <spirv/workitem/get_num_sub_groups.h>
#include <spirv/workitem/get_sub_group_id.h>
#include <spirv/workitem/get_sub_group_local_id.h>
#include <spirv/workitem/get_sub_group_size.h>
#include <spirv/workitem/get_work_dim.h>

/* 6.11.2.1 Floating-point macros */
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Declares __spirv_SubgroupMaxSize: takes no arguments and returns a uint.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupMaxSize();
9 changes: 9 additions & 0 deletions libclc/generic/include/spirv/workitem/get_num_sub_groups.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Declares __spirv_NumSubgroups: takes no arguments and returns a uint.
_CLC_DEF _CLC_OVERLOAD uint __spirv_NumSubgroups();
9 changes: 9 additions & 0 deletions libclc/generic/include/spirv/workitem/get_sub_group_id.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Declares __spirv_SubgroupId: takes no arguments and returns a uint.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupId();
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Declares __spirv_SubgroupLocalInvocationId: takes no arguments and returns
// a uint.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupLocalInvocationId();
9 changes: 9 additions & 0 deletions libclc/generic/include/spirv/workitem/get_sub_group_size.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Declares __spirv_SubgroupSize: takes no arguments and returns a uint.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupSize();
5 changes: 5 additions & 0 deletions libclc/ptx-nvidiacl/libspirv/SOURCES
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ workitem/get_global_size.cl
workitem/get_group_id.cl
workitem/get_local_id.cl
workitem/get_local_size.cl
workitem/get_max_sub_group_size.cl
workitem/get_num_groups.cl
workitem/get_num_sub_groups.cl
workitem/get_sub_group_id.cl
workitem/get_sub_group_local_id.cl
workitem/get_sub_group_size.cl
images/image_helpers.ll
images/image.cl
15 changes: 15 additions & 0 deletions libclc/ptx-nvidiacl/libspirv/workitem/get_max_sub_group_size.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <spirv/spirv.h>

// Reports the maximum sub-group size as a fixed constant.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupMaxSize() {
  // FIXME: warpsize is defined by NVVM IR but doesn't compile if used here,
  // so the value is hard-coded rather than obtained from
  // __nvvm_read_ptx_sreg_warpsize().
  const uint max_size = 32;
  return max_size;
}
20 changes: 20 additions & 0 deletions libclc/ptx-nvidiacl/libspirv/workitem/get_num_sub_groups.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <spirv/spirv.h>

// Computes the number of sub-groups from the work-group size: the product of
// the three work-group dimensions divided by the maximum sub-group size,
// rounded up.
_CLC_DEF _CLC_OVERLOAD uint __spirv_NumSubgroups() {
  // sreg.nwarpid returns number of warp identifiers, not number of warps
  // see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
  size_t dim_x = __spirv_WorkgroupSize_x();
  size_t dim_y = __spirv_WorkgroupSize_y();
  size_t dim_z = __spirv_WorkgroupSize_z();
  uint max_sg = __spirv_SubgroupMaxSize();
  uint total_items = dim_z * dim_y * dim_x;
  // Adding (max_sg - 1) before dividing rounds the quotient up, so a trailing
  // partially filled sub-group is counted as well.
  uint rounded_up = total_items + max_sg - 1;
  return rounded_up / max_sg;
}
22 changes: 22 additions & 0 deletions libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_id.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <spirv/spirv.h>

// Computes the sub-group index of the calling work-item: its linearized local
// id (x varying fastest, then y, then z) divided by the maximum sub-group
// size.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupId() {
  // sreg.warpid is volatile and doesn't represent virtual warp index
  // see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
  size_t id_x = __spirv_LocalInvocationId_x();
  size_t id_y = __spirv_LocalInvocationId_y();
  size_t id_z = __spirv_LocalInvocationId_z();
  // Only the x and y extents are needed to linearize the id; the z extent
  // never enters the formula, so it is not queried.
  size_t size_x = __spirv_WorkgroupSize_x();
  size_t size_y = __spirv_WorkgroupSize_y();
  uint sg_size = __spirv_SubgroupMaxSize();
  size_t linear_id = id_z * size_y * size_x + id_y * size_x + id_x;
  return linear_id / sg_size;
}
13 changes: 13 additions & 0 deletions libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_local_id.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <spirv/spirv.h>

// Returns the value produced by __nvvm_read_ptx_sreg_laneid(), converted to
// uint.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupLocalInvocationId() {
  uint lane = __nvvm_read_ptx_sreg_laneid();
  return lane;
}
22 changes: 22 additions & 0 deletions libclc/ptx-nvidiacl/libspirv/workitem/get_sub_group_size.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <spirv/spirv.h>

// Returns the size of the calling work-item's sub-group. Every sub-group
// except the last one has the maximum size; the last one holds whatever
// work-items remain in the work-group.
_CLC_DEF _CLC_OVERLOAD uint __spirv_SubgroupSize() {
  // Index of the last sub-group, which is also the number of sub-groups that
  // precede it.
  uint last_sg = __spirv_NumSubgroups() - 1;
  if (__spirv_SubgroupId() == last_sg) {
    size_t dim_x = __spirv_WorkgroupSize_x();
    size_t dim_y = __spirv_WorkgroupSize_y();
    size_t dim_z = __spirv_WorkgroupSize_z();
    uint total_items = dim_z * dim_y * dim_x;
    // Work-items consumed by the preceding, maximum-size sub-groups.
    uint items_before_last = __spirv_SubgroupMaxSize() * last_sg;
    return total_items - items_before_last;
  }
  return __spirv_SubgroupMaxSize();
}
35 changes: 28 additions & 7 deletions sycl/include/CL/__spirv/spirv_vars.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ SYCL_EXTERNAL size_t __spirv_LocalInvocationId_x();
SYCL_EXTERNAL size_t __spirv_LocalInvocationId_y();
SYCL_EXTERNAL size_t __spirv_LocalInvocationId_z();

// Sub-group query functions. Only declared here (no body in this header);
// each takes no arguments and returns a 32-bit unsigned value.
SYCL_EXTERNAL uint32_t __spirv_SubgroupSize();
SYCL_EXTERNAL uint32_t __spirv_SubgroupMaxSize();
SYCL_EXTERNAL uint32_t __spirv_NumSubgroups();
SYCL_EXTERNAL uint32_t __spirv_SubgroupId();
SYCL_EXTERNAL uint32_t __spirv_SubgroupLocalInvocationId();

#else // __SYCL_NVPTX__

typedef size_t size_t_vec __attribute__((ext_vector_type(3)));
Expand All @@ -56,6 +62,12 @@ __SPIRV_VAR_QUALIFIERS size_t_vec __spirv_BuiltInLocalInvocationId;
__SPIRV_VAR_QUALIFIERS size_t_vec __spirv_BuiltInWorkgroupId;
__SPIRV_VAR_QUALIFIERS size_t_vec __spirv_BuiltInGlobalOffset;

// Sub-group built-in variables, each a 32-bit unsigned scalar declared with
// __SPIRV_VAR_QUALIFIERS.
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupSize;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupMaxSize;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInNumSubgroups;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupId;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupLocalInvocationId;

// Returns the x component of the __spirv_BuiltInGlobalInvocationId vector.
SYCL_EXTERNAL inline size_t __spirv_GlobalInvocationId_x() {
return __spirv_BuiltInGlobalInvocationId.x;
}
Expand Down Expand Up @@ -126,14 +138,23 @@ SYCL_EXTERNAL inline size_t __spirv_LocalInvocationId_z() {
return __spirv_BuiltInLocalInvocationId.z;
}

#endif // __SYCL_NVPTX__
// Function-style accessor: returns the __spirv_BuiltInSubgroupSize variable.
SYCL_EXTERNAL inline uint32_t __spirv_SubgroupSize() {
return __spirv_BuiltInSubgroupSize;
}
// Function-style accessor: returns the __spirv_BuiltInSubgroupMaxSize
// variable.
SYCL_EXTERNAL inline uint32_t __spirv_SubgroupMaxSize() {
return __spirv_BuiltInSubgroupMaxSize;
}
// Function-style accessor: returns the __spirv_BuiltInNumSubgroups variable.
SYCL_EXTERNAL inline uint32_t __spirv_NumSubgroups() {
return __spirv_BuiltInNumSubgroups;
}
// Function-style accessor: returns the __spirv_BuiltInSubgroupId variable.
SYCL_EXTERNAL inline uint32_t __spirv_SubgroupId() {
return __spirv_BuiltInSubgroupId;
}
// Function-style accessor: returns the
// __spirv_BuiltInSubgroupLocalInvocationId variable.
SYCL_EXTERNAL inline uint32_t __spirv_SubgroupLocalInvocationId() {
return __spirv_BuiltInSubgroupLocalInvocationId;
}

__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupSize;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupMaxSize;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInNumSubgroups;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInNumEnqueuedSubgroups;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupId;
__SPIRV_VAR_QUALIFIERS uint32_t __spirv_BuiltInSubgroupLocalInvocationId;
#endif // __SYCL_NVPTX__

#undef __SPIRV_VAR_QUALIFIERS

Expand Down
10 changes: 5 additions & 5 deletions sycl/include/CL/sycl/ONEAPI/sub_group.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ struct sub_group {

id_type get_local_id() const {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_BuiltInSubgroupLocalInvocationId;
return __spirv_SubgroupLocalInvocationId();
#else
throw runtime_error("Sub-groups are not supported on host device.",
PI_INVALID_DEVICE);
Expand All @@ -127,7 +127,7 @@ struct sub_group {

range_type get_local_range() const {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_BuiltInSubgroupSize;
return __spirv_SubgroupSize();
#else
throw runtime_error("Sub-groups are not supported on host device.",
PI_INVALID_DEVICE);
Expand All @@ -136,7 +136,7 @@ struct sub_group {

range_type get_max_local_range() const {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_BuiltInSubgroupMaxSize;
return __spirv_SubgroupMaxSize();
#else
throw runtime_error("Sub-groups are not supported on host device.",
PI_INVALID_DEVICE);
Expand All @@ -145,7 +145,7 @@ struct sub_group {

id_type get_group_id() const {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_BuiltInSubgroupId;
return __spirv_SubgroupId();
#else
throw runtime_error("Sub-groups are not supported on host device.",
PI_INVALID_DEVICE);
Expand All @@ -163,7 +163,7 @@ struct sub_group {

range_type get_group_range() const {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_BuiltInNumSubgroups;
return __spirv_NumSubgroups();
#else
throw runtime_error("Sub-groups are not supported on host device.",
PI_INVALID_DEVICE);
Expand Down
5 changes: 1 addition & 4 deletions sycl/test/sub_group/common.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// UNSUPPORTED: cuda
// CUDA compilation and runtime do not yet support sub-groups.
//
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
// RUN: %CPU_RUN_PLACEHOLDER %t.out
Expand Down Expand Up @@ -70,7 +67,7 @@ void check(queue &Queue, unsigned int G, unsigned int L) {
}
int main() {
queue Queue;
if (!core_sg_supported(Queue.get_device())) {
if (Queue.get_device().is_host()) {
std::cout << "Skipping test\n";
return 0;
}
Expand Down