Skip to content

Commit 60f6e16

Browse files
authored
[SYCL] Implement work group memory extension (#15178)
Implement work group memory extension: https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_work_group_memory.asciidoc Two notes: - Free function kernel support for work group memory argument will be added in a future PR. - When the assignment operator is called in host code, the assigned to work group memory object does not actually correspond to the same underlying memory as the one that was assigned from contradicting the spec. See KhronosGroup/SYCL-Docs#552 for a similar problem with `local_accessor`
1 parent c7194c2 commit 60f6e16

File tree

21 files changed

+687
-11
lines changed

21 files changed

+687
-11
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,12 +1566,12 @@ def SYCLType: InheritableAttr {
15661566
let Subjects = SubjectList<[CXXRecord, Enum], ErrorDiag>;
15671567
let LangOpts = [SYCLIsDevice, SilentlyIgnoreSYCLIsHost];
15681568
let Args = [EnumArgument<"Type", "SYCLType", /*is_string=*/true,
1569-
["accessor", "local_accessor",
1569+
["accessor", "local_accessor", "work_group_memory",
15701570
"specialization_id", "kernel_handler", "buffer_location",
15711571
"no_alias", "accessor_property_list", "group",
15721572
"private_memory", "aspect", "annotated_ptr", "annotated_arg",
15731573
"stream", "sampler", "host_pipe", "multi_ptr"],
1574-
["accessor", "local_accessor",
1574+
["accessor", "local_accessor", "work_group_memory",
15751575
"specialization_id", "kernel_handler", "buffer_location",
15761576
"no_alias", "accessor_property_list", "group",
15771577
"private_memory", "aspect", "annotated_ptr", "annotated_arg",

clang/include/clang/Sema/SemaSYCL.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ class SYCLIntegrationHeader {
6262
kind_pointer,
6363
kind_specialization_constants_buffer,
6464
kind_stream,
65-
kind_last = kind_stream
65+
kind_work_group_memory,
66+
kind_last = kind_work_group_memory
6667
};
6768

6869
public:

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4693,6 +4693,9 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
46934693
CurOffset + offsetOf(FD, FieldTy));
46944694
} else if (SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::stream)) {
46954695
addParam(FD, FieldTy, SYCLIntegrationHeader::kind_stream);
4696+
} else if (SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::work_group_memory)) {
4697+
addParam(FieldTy, SYCLIntegrationHeader::kind_work_group_memory,
4698+
offsetOf(FD, FieldTy));
46964699
} else if (SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::sampler) ||
46974700
SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::annotated_ptr) ||
46984701
SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::annotated_arg)) {
@@ -5773,6 +5776,7 @@ static const char *paramKind2Str(KernelParamKind K) {
57735776
CASE(stream);
57745777
CASE(specialization_constants_buffer);
57755778
CASE(pointer);
5779+
CASE(work_group_memory);
57765780
}
57775781
return "<ERROR>";
57785782

clang/test/CodeGenSYCL/Inputs/sycl.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,24 @@ const stream& operator<<(const stream &S, T&&) {
649649
return S;
650650
}
651651

652+
// Dummy implementation of work_group_memory for use in CodeGenSYCL tests.
653+
template <typename DataT>
654+
class __attribute__((sycl_special_class))
655+
__SYCL_TYPE(work_group_memory) work_group_memory {
656+
public:
657+
work_group_memory(handler &CGH) {}
658+
#ifdef __SYCL_DEVICE_ONLY__
659+
// Default constructor for objects later initialized with __init member.
660+
work_group_memory() = default;
661+
#endif
662+
663+
void __init(__attribute((opencl_local)) DataT *Ptr) { this->Ptr = Ptr; }
664+
__attribute((opencl_local)) DataT *operator&() const { return Ptr; }
665+
666+
private:
667+
__attribute((opencl_local)) DataT *Ptr;
668+
};
669+
652670
template <typename T, int dimensions = 1,
653671
typename AllocatorT = int /*fake type as AllocatorT is not used*/>
654672
class buffer {
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -triple spir64-unknown-unknown -disable-llvm-passes -emit-llvm %s -o %t.ll
2+
// RUN: FileCheck < %t.ll %s --check-prefix CHECK-IR
3+
// RUN: %clang_cc1 -fsycl-is-device -triple spir64-unknown-unknown -fsycl-int-header=%t.h %s
4+
// RUN: FileCheck < %t.h %s --check-prefix CHECK-INT-HEADER
5+
//
6+
// Tests for work_group_memory kernel parameter using the dummy implementation in Inputs/sycl.hpp.
7+
// The first two RUN commands verify that the init call is generated with the correct arguments in LLVM IR
8+
// and the second two RUN commands verify the contents of the integration header produced by the frontend.
9+
//
10+
// CHECK-IR: define dso_local spir_kernel void @
11+
// CHECK-IR-SAME: ptr addrspace(3) noundef align 4 [[PTR:%[a-zA-Z0-9_]+]]
12+
//
13+
// CHECK-IR: [[PTR]].addr = alloca ptr addrspace(3), align 8
14+
// CHECK-IR: [[PTR]].addr.ascast = addrspacecast ptr [[PTR]].addr to ptr addrspace(4)
15+
// CHECK-IR: store ptr addrspace(3) [[PTR]], ptr addrspace(4) [[PTR]].addr.ascast, align 8
16+
// CHECK-IR: [[PTR_LOAD:%[a-zA-Z0-9_]+]] = load ptr addrspace(3), ptr addrspace(4) [[PTR]].addr.ascast, align 8
17+
//
18+
// CHECK-IR: call spir_func void @{{.*}}__init{{.*}}(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %{{[a-zA-Z0-9_]+}}, ptr addrspace(3) noundef [[PTR_LOAD]])
19+
//
20+
// CHECK-INT-HEADER: const kernel_param_desc_t kernel_signatures[] = {
21+
// CHECK-INT-HEADER-NEXT: //--- _ZTSZZ4mainENKUlRN4sycl3_V17handlerEE_clES2_EUlNS0_4itemILi1EEEE_
22+
// CHECK-INT-HEADER-NEXT: { kernel_param_kind_t::kind_work_group_memory, {{[4,8]}}, 0 },
23+
// CHECK-INT-HEADER-EMPTY:
24+
// CHECK-INT-HEADER-NEXT: { kernel_param_kind_t::kind_invalid, -987654321, -987654321 },
25+
// CHECK-INT-HEADER-NEXT: };
26+
27+
#include "Inputs/sycl.hpp"
28+
29+
int main() {
30+
sycl::queue Q;
31+
Q.submit([&](sycl::handler &CGH) {
32+
sycl::work_group_memory<int> mem;
33+
sycl::range<1> ndr;
34+
CGH.parallel_for(ndr, [=](sycl::item<1> it) { int *ptr = &mem; });
35+
});
36+
return 0;
37+
}

clang/test/SemaSYCL/Inputs/sycl/detail/kernel_desc.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ namespace detail {
1818
kind_pointer = 3,
1919
kind_specialization_constants_buffer = 4,
2020
kind_stream = 5,
21+
kind_work_group_memory = 6,
2122
kind_invalid = 0xf, // not a valid kernel kind
2223
};
2324

sycl-jit/common/include/Kernel.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ enum class ParameterKind : uint32_t {
5858
Pointer = 3,
5959
SpecConstBuffer = 4,
6060
Stream = 5,
61+
WorkGroupMemory = 6,
6162
Invalid = 0xF,
6263
};
6364

sycl/include/sycl/detail/kernel_desc.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ enum class kernel_param_kind_t {
5858
kind_pointer = 3,
5959
kind_specialization_constants_buffer = 4,
6060
kind_stream = 5,
61+
kind_work_group_memory = 6,
6162
kind_invalid = 0xf, // not a valid kernel kind
6263
};
6364

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
//===-------------------- work_group_memory.hpp ---------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
#include <type_traits>
12+
13+
namespace sycl {
14+
inline namespace _V1 {
15+
namespace detail {
16+
template <typename T> struct is_unbounded_array : std::false_type {};
17+
18+
template <typename T> struct is_unbounded_array<T[]> : std::true_type {};
19+
20+
template <typename T>
21+
inline constexpr bool is_unbounded_array_v = is_unbounded_array<T>::value;
22+
23+
class work_group_memory_impl {
24+
public:
25+
work_group_memory_impl() : buffer_size{0} {}
26+
work_group_memory_impl(const work_group_memory_impl &rhs) = default;
27+
work_group_memory_impl &
28+
operator=(const work_group_memory_impl &rhs) = default;
29+
work_group_memory_impl(size_t buffer_size) : buffer_size{buffer_size} {}
30+
31+
private:
32+
size_t buffer_size;
33+
friend class sycl::handler;
34+
};
35+
36+
} // namespace detail
37+
namespace ext::oneapi::experimental {
38+
39+
template <typename DataT, typename PropertyListT = empty_properties_t>
40+
class __SYCL_SPECIAL_CLASS __SYCL_TYPE(work_group_memory) work_group_memory
41+
: sycl::detail::work_group_memory_impl {
42+
public:
43+
using value_type = std::remove_all_extents_t<DataT>;
44+
45+
private:
46+
using decoratedPtr = typename sycl::detail::DecoratedType<
47+
value_type, access::address_space::local_space>::type *;
48+
49+
public:
50+
work_group_memory() = default;
51+
work_group_memory(const work_group_memory &rhs) = default;
52+
work_group_memory &operator=(const work_group_memory &rhs) = default;
53+
template <typename T = DataT,
54+
typename = std::enable_if_t<!sycl::detail::is_unbounded_array_v<T>>>
55+
work_group_memory(handler &)
56+
: sycl::detail::work_group_memory_impl(sizeof(DataT)) {}
57+
template <typename T = DataT,
58+
typename = std::enable_if_t<sycl::detail::is_unbounded_array_v<T>>>
59+
work_group_memory(size_t num, handler &)
60+
: sycl::detail::work_group_memory_impl(
61+
num * sizeof(std::remove_extent_t<DataT>)) {}
62+
template <access::decorated IsDecorated = access::decorated::no>
63+
multi_ptr<value_type, access::address_space::local_space, IsDecorated>
64+
get_multi_ptr() const {
65+
return sycl::address_space_cast<access::address_space::local_space,
66+
IsDecorated, value_type>(ptr);
67+
}
68+
DataT *operator&() const { return reinterpret_cast<DataT *>(ptr); }
69+
operator DataT &() const { return *reinterpret_cast<DataT *>(ptr); }
70+
template <typename T = DataT,
71+
typename = std::enable_if_t<!std::is_array_v<T>>>
72+
const work_group_memory &operator=(const DataT &value) const {
73+
*ptr = value;
74+
return *this;
75+
}
76+
#ifdef __SYCL_DEVICE_ONLY__
77+
void __init(decoratedPtr ptr) { this->ptr = ptr; }
78+
#endif
79+
private:
80+
decoratedPtr ptr;
81+
};
82+
} // namespace ext::oneapi::experimental
83+
} // namespace _V1
84+
} // namespace sycl

sycl/include/sycl/handler.hpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ class pipe;
163163
}
164164

165165
namespace ext ::oneapi ::experimental {
166+
template <typename, typename>
167+
class work_group_memory;
166168
struct image_descriptor;
167169
} // namespace ext::oneapi::experimental
168170

@@ -171,6 +173,7 @@ class graph_impl;
171173
} // namespace ext::oneapi::experimental::detail
172174
namespace detail {
173175

176+
class work_group_memory_impl;
174177
class handler_impl;
175178
class kernel_impl;
176179
class queue_impl;
@@ -564,8 +567,8 @@ class __SYCL_EXPORT handler {
564567
// The version for regular(standard layout) argument.
565568
template <typename T, typename... Ts>
566569
void setArgsHelper(int ArgIndex, T &&Arg, Ts &&...Args) {
567-
set_arg(ArgIndex, std::move(Arg));
568-
setArgsHelper(++ArgIndex, std::move(Args)...);
570+
set_arg(ArgIndex, std::forward<T>(Arg));
571+
setArgsHelper(++ArgIndex, std::forward<Ts>(Args)...);
569572
}
570573

571574
void setArgsHelper(int) {}
@@ -603,6 +606,8 @@ class __SYCL_EXPORT handler {
603606
#endif
604607
}
605608

609+
void setArgHelper(int ArgIndex, detail::work_group_memory_impl &Arg);
610+
606611
// setArgHelper for non local accessor argument.
607612
template <typename DataT, int Dims, access::mode AccessMode,
608613
access::target AccessTarget, access::placeholder IsPlaceholder>
@@ -1096,7 +1101,7 @@ class __SYCL_EXPORT handler {
10961101
KernelType KernelFunc) {
10971102
#ifndef __SYCL_DEVICE_ONLY__
10981103
throwIfActionIsCreated();
1099-
throwOnLocalAccessorMisuse<KernelName, KernelType>();
1104+
throwOnKernelParameterMisuse<KernelName, KernelType>();
11001105
if (!range_size_fits_in_size_t(UserRange))
11011106
throw sycl::exception(make_error_code(errc::runtime),
11021107
"The total number of work-items in "
@@ -1641,7 +1646,7 @@ class __SYCL_EXPORT handler {
16411646
kernel_single_task_wrapper<NameT, KernelType, PropertiesT>(KernelFunc);
16421647
#ifndef __SYCL_DEVICE_ONLY__
16431648
throwIfActionIsCreated();
1644-
throwOnLocalAccessorMisuse<KernelName, KernelType>();
1649+
throwOnKernelParameterMisuse<KernelName, KernelType>();
16451650
verifyUsedKernelBundleInternal(
16461651
detail::string_view{detail::getKernelName<NameT>()});
16471652
// No need to check if range is out of INT_MAX limits as it's compile-time
@@ -1840,6 +1845,14 @@ class __SYCL_EXPORT handler {
18401845
setArgHelper(ArgIndex, std::move(Arg));
18411846
}
18421847

1848+
template <typename DataT, typename PropertyListT =
1849+
ext::oneapi::experimental::empty_properties_t>
1850+
void set_arg(
1851+
int ArgIndex,
1852+
ext::oneapi::experimental::work_group_memory<DataT, PropertyListT> &Arg) {
1853+
setArgHelper(ArgIndex, Arg);
1854+
}
1855+
18431856
// set_arg for graph dynamic_parameters
18441857
template <typename T>
18451858
void set_arg(int argIndex,
@@ -1858,9 +1871,8 @@ class __SYCL_EXPORT handler {
18581871
///
18591872
/// \param Args are argument values to be set.
18601873
template <typename... Ts> void set_args(Ts &&...Args) {
1861-
setArgsHelper(0, std::move(Args)...);
1874+
setArgsHelper(0, std::forward<Ts>(Args)...);
18621875
}
1863-
18641876
/// Defines and invokes a SYCL kernel function as a function object type.
18651877
///
18661878
/// If it is a named function object and the function object type is
@@ -3233,7 +3245,6 @@ class __SYCL_EXPORT handler {
32333245
private:
32343246
std::shared_ptr<detail::handler_impl> impl;
32353247
std::shared_ptr<detail::queue_impl> MQueue;
3236-
32373248
std::vector<detail::LocalAccessorImplPtr> MLocalAccStorage;
32383249
std::vector<std::shared_ptr<detail::stream_impl>> MStreamStorage;
32393250
detail::string MKernelName;
@@ -3554,7 +3565,7 @@ class __SYCL_EXPORT handler {
35543565
/// must not be used in a SYCL kernel function that is invoked via single_task
35553566
/// or via the simple form of parallel_for that takes a range parameter.
35563567
template <typename KernelName, typename KernelType>
3557-
void throwOnLocalAccessorMisuse() const {
3568+
void throwOnKernelParameterMisuse() const {
35583569
using NameT =
35593570
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
35603571
for (unsigned I = 0; I < detail::getKernelNumParams<NameT>(); ++I) {
@@ -3570,6 +3581,12 @@ class __SYCL_EXPORT handler {
35703581
"A local accessor must not be used in a SYCL kernel function "
35713582
"that is invoked via single_task or via the simple form of "
35723583
"parallel_for that takes a range parameter.");
3584+
if (Kind == detail::kernel_param_kind_t::kind_work_group_memory)
3585+
throw sycl::exception(
3586+
make_error_code(errc::kernel_argument),
3587+
"A work group memory object must not be used in a SYCL kernel "
3588+
"function that is invoked via single_task or via the simple form "
3589+
"of parallel_for that takes a range parameter.");
35733590
}
35743591
}
35753592

sycl/include/sycl/sycl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
#include <sycl/ext/oneapi/experimental/raw_kernel_arg.hpp>
101101
#include <sycl/ext/oneapi/experimental/root_group.hpp>
102102
#include <sycl/ext/oneapi/experimental/tangle_group.hpp>
103+
#include <sycl/ext/oneapi/experimental/work_group_memory.hpp>
103104
#include <sycl/ext/oneapi/filter_selector.hpp>
104105
#include <sycl/ext/oneapi/free_function_queries.hpp>
105106
#include <sycl/ext/oneapi/functional.hpp>

sycl/source/detail/handler_impl.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,9 @@ class handler_impl {
197197

198198
/// True if MCodeLoc is sycl entry point code location
199199
bool MIsTopCodeLoc = true;
200+
201+
/// List of work group memory objects associated with this handler
202+
std::vector<std::shared_ptr<detail::work_group_memory_impl>> MWorkGroupMemoryObjects;
200203
};
201204

202205
} // namespace detail

sycl/source/detail/jit_compiler.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ translateArgType(kernel_param_kind_t Kind) {
133133
return PK::SpecConstBuffer;
134134
case kind::kind_stream:
135135
return PK::Stream;
136+
case kind::kind_work_group_memory:
137+
return PK::WorkGroupMemory;
136138
case kind::kind_invalid:
137139
return PK::Invalid;
138140
}

sycl/source/detail/scheduler/commands.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2297,6 +2297,8 @@ void SetArgBasedOnType(
22972297
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
22982298
const sycl::context &Context, detail::ArgDesc &Arg, size_t NextTrueIndex) {
22992299
switch (Arg.MType) {
2300+
case kernel_param_kind_t::kind_work_group_memory:
2301+
break;
23002302
case kernel_param_kind_t::kind_stream:
23012303
break;
23022304
case kernel_param_kind_t::kind_accessor: {

sycl/source/handler.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <sycl/stream.hpp>
3535

3636
#include <sycl/ext/oneapi/bindless_images_memory.hpp>
37+
#include <sycl/ext/oneapi/experimental/work_group_memory.hpp>
3738
#include <sycl/ext/oneapi/memcpy2d.hpp>
3839

3940
namespace sycl {
@@ -795,6 +796,12 @@ void handler::processArg(void *Ptr, const detail::kernel_param_kind_t &Kind,
795796
}
796797
break;
797798
}
799+
case kernel_param_kind_t::kind_work_group_memory: {
800+
addArg(kernel_param_kind_t::kind_std_layout, nullptr,
801+
static_cast<detail::work_group_memory_impl *>(Ptr)->buffer_size,
802+
Index + IndexShift);
803+
break;
804+
}
798805
case kernel_param_kind_t::kind_sampler: {
799806
addArg(kernel_param_kind_t::kind_sampler, Ptr, sizeof(sampler),
800807
Index + IndexShift);
@@ -812,6 +819,13 @@ void handler::processArg(void *Ptr, const detail::kernel_param_kind_t &Kind,
812819
}
813820
}
814821

822+
void handler::setArgHelper(int ArgIndex, detail::work_group_memory_impl &Arg) {
823+
impl->MWorkGroupMemoryObjects.push_back(
824+
std::make_shared<detail::work_group_memory_impl>(Arg));
825+
addArg(detail::kernel_param_kind_t::kind_work_group_memory,
826+
impl->MWorkGroupMemoryObjects.back().get(), 0, ArgIndex);
827+
}
828+
815829
// The argument can take up more space to store additional information about
816830
// MAccessRange, MMemoryRange, and MOffset added with addArgsForGlobalAccessor.
817831
// We use the worst-case estimate because the lifetime of the vector is short.

0 commit comments

Comments
 (0)