Skip to content

Commit 21b195c

Browse files
Implemented dynamic_work_groups, with a hack and few issues
1 parent 7f5da80 commit 21b195c

File tree

13 files changed

+239
-73
lines changed

13 files changed

+239
-73
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,12 +1606,12 @@ def SYCLType: InheritableAttr {
16061606
let Subjects = SubjectList<[CXXRecord, Enum], ErrorDiag>;
16071607
let LangOpts = [SYCLIsDevice, SilentlyIgnoreSYCLIsHost];
16081608
let Args = [EnumArgument<"Type", "SYCLType", /*is_string=*/true,
1609-
["accessor", "local_accessor", "work_group_memory",
1609+
["accessor", "local_accessor", "work_group_memory", "dynamic_work_group_memory",
16101610
"specialization_id", "kernel_handler", "buffer_location",
16111611
"no_alias", "accessor_property_list", "group",
16121612
"private_memory", "aspect", "annotated_ptr", "annotated_arg",
16131613
"stream", "sampler", "host_pipe", "multi_ptr"],
1614-
["accessor", "local_accessor", "work_group_memory",
1614+
["accessor", "local_accessor", "work_group_memory", "dynamic_work_group_memory",
16151615
"specialization_id", "kernel_handler", "buffer_location",
16161616
"no_alias", "accessor_property_list", "group",
16171617
"private_memory", "aspect", "annotated_ptr", "annotated_arg",

clang/include/clang/Sema/SemaSYCL.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ class SYCLIntegrationHeader {
6363
kind_specialization_constants_buffer,
6464
kind_stream,
6565
kind_work_group_memory,
66-
kind_last = kind_work_group_memory
66+
kind_dynamic_work_group_memory,
67+
kind_last = kind_dynamic_work_group_memory
6768
};
6869

6970
public:
@@ -666,7 +667,7 @@ class SemaSYCL : public SemaBase {
666667
// Used to check whether the function represented by FD is a SYCL
667668
// free function kernel or not.
668669
bool isFreeFunction(const FunctionDecl *FD);
669-
670+
670671
StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, CompoundStmt *Body);
671672
};
672673

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 64 additions & 60 deletions
Large diffs are not rendered by default.

sycl-jit/common/include/Kernel.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ enum class ParameterKind : uint32_t {
6060
SpecConstBuffer = 4,
6161
Stream = 5,
6262
WorkGroupMemory = 6,
63+
DynamicWorkGroupMemory = 7,
6364
Invalid = 0xF,
6465
};
6566

@@ -239,8 +240,8 @@ class NDRange {
239240
NDRange(int Dimensions, const Indices &GlobalSize,
240241
const Indices &LocalSize = {1, 1, 1},
241242
const Indices &Offset = {0, 0, 0})
242-
: Dimensions{Dimensions},
243-
GlobalSize{GlobalSize}, LocalSize{LocalSize}, Offset{Offset} {
243+
: Dimensions{Dimensions}, GlobalSize{GlobalSize}, LocalSize{LocalSize},
244+
Offset{Offset} {
244245
#ifndef NDEBUG
245246
const auto CheckDim = [Dimensions](const Indices &Range) {
246247
return std::all_of(Range.begin() + Dimensions, Range.end(),

sycl/include/sycl/detail/kernel_desc.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ enum class kernel_param_kind_t {
5959
kind_specialization_constants_buffer = 4,
6060
kind_stream = 5,
6161
kind_work_group_memory = 6,
62+
kind_dynamic_work_group_memory = 7,
6263
kind_invalid = 0xf, // not a valid kernel kind
6364
};
6465

sycl/include/sycl/ext/oneapi/experimental/graph.hpp

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#pragma once
1010

11+
#include "sycl/ext/oneapi/experimental/graph.hpp"
12+
#include <cstddef>
1113
#include <sycl/accessor.hpp> // for detail::AccessorBaseHost
1214
#include <sycl/context.hpp> // for context
1315
#include <sycl/detail/export.hpp> // for __SYCL_EXPORT
@@ -17,8 +19,9 @@
1719
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
1820
#include <sycl/detail/string_view.hpp>
1921
#endif
20-
#include <sycl/device.hpp> // for device
22+
#include <sycl/device.hpp> // for device
2123
#include <sycl/ext/oneapi/experimental/detail/properties/graph_properties.hpp> // for graph properties classes
24+
#include <sycl/ext/oneapi/experimental/work_group_memory.hpp> // for dynamic_work_group_memory
2225
#include <sycl/nd_range.hpp> // for range, nd_range
2326
#include <sycl/properties/property_traits.hpp> // for is_property, is_property_of
2427
#include <sycl/property_list.hpp> // for property_list
@@ -485,6 +488,10 @@ class command_graph<graph_state::executable>
485488
namespace detail {
486489
class __SYCL_EXPORT dynamic_parameter_base {
487490
public:
491+
dynamic_parameter_base() = default;
492+
dynamic_parameter_base(
493+
sycl::ext::oneapi::experimental::command_graph<graph_state::modifiable>
494+
Graph);
488495
dynamic_parameter_base(
489496
sycl::ext::oneapi::experimental::command_graph<graph_state::modifiable>
490497
Graph,
@@ -509,14 +516,72 @@ class __SYCL_EXPORT dynamic_parameter_base {
509516
void updateValue(const raw_kernel_arg *NewRawValue, size_t Size);
510517

511518
void updateAccessor(const sycl::detail::AccessorBaseHost *Acc);
519+
520+
void updateWorkGroupMem(size_t BufferSize);
521+
512522
std::shared_ptr<dynamic_parameter_impl> impl;
513523

514524
template <class Obj>
515525
friend const decltype(Obj::impl) &
516526
sycl::detail::getSyclObjImpl(const Obj &SyclObject);
517527
};
528+
518529
} // namespace detail
519530

531+
template <typename T> struct is_unbounded_array : std::false_type {};
532+
533+
template <typename T> struct is_unbounded_array<T[]> : std::true_type {};
534+
535+
template <typename T>
536+
inline constexpr bool is_unbounded_array_v = is_unbounded_array<T>::value;
537+
538+
template <typename DataT,
539+
typename = std::enable_if_t<is_unbounded_array_v<DataT>>>
540+
541+
class __SYCL_SPECIAL_CLASS
542+
__SYCL_TYPE(dynamic_work_group_memory) dynamic_work_group_memory
543+
#ifdef __SYCL_DEVICE_ONLY__
544+
: detail::dynamic_parameter_base
545+
#else
546+
: public detail::dynamic_parameter_base
547+
#endif
548+
{
549+
public:
550+
dynamic_work_group_memory() = default;
551+
/// Constructs a new dynamic_work_group_memory object.
552+
/// @param Graph The graph associated with this object.
553+
/// @param Num Number of elements in the unbounded array DataT.
554+
dynamic_work_group_memory(
555+
experimental::command_graph<graph_state::modifiable> Graph, size_t Num) {
556+
auto &WorkGroupMemImpl =
557+
static_cast<detail::work_group_memory_impl &>(WorkGroupMem);
558+
WorkGroupMemImpl.buffer_size = Num * sizeof(std::remove_extent_t<DataT>);
559+
}
560+
561+
/// Updates this dynamic_work_group_memory and all registered nodes with a new
562+
/// number of elements.
563+
/// @param Num The new number of elements in the unbounded array.
564+
void update(size_t Num) {
565+
#ifndef __SYCL_DEVICE_ONLY__
566+
detail::dynamic_parameter_base::updateWorkGroupMem(
567+
Num * sizeof(std::remove_extent_t<DataT>));
568+
#endif
569+
}
570+
571+
const work_group_memory<DataT> &get() const { return WorkGroupMem; }
572+
573+
private:
574+
work_group_memory<DataT> WorkGroupMem;
575+
#ifdef __SYCL_DEVICE_ONLY__
576+
// [[maybe_unused]] char padding[sizeof(detail::dynamic_parameter_base)];
577+
using value_type = std::remove_all_extents_t<DataT>;
578+
using decoratedPtr = typename sycl::detail::DecoratedType<
579+
value_type, access::address_space::local_space>::type *;
580+
581+
void __init(decoratedPtr Ptr) { this->WorkGroupMem.__init(Ptr); }
582+
#endif
583+
};
584+
520585
template <typename ValueT>
521586
class dynamic_parameter : public detail::dynamic_parameter_base {
522587
static constexpr bool IsAccessor =

sycl/include/sycl/ext/oneapi/experimental/work_group_memory.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#pragma once
99

10+
#include "sycl/ext/oneapi/experimental/graph.hpp"
1011
#include <sycl/access/access.hpp>
1112
#include <sycl/detail/defines.hpp>
1213
#include <sycl/ext/oneapi/properties/properties.hpp>
@@ -19,6 +20,14 @@ namespace sycl {
1920
inline namespace _V1 {
2021
class handler;
2122

23+
namespace ext {
24+
namespace oneapi {
25+
namespace experimental {
26+
template <typename DataT, typename Enable> class dynamic_work_group_memory;
27+
}
28+
} // namespace oneapi
29+
} // namespace ext
30+
2231
namespace detail {
2332
template <typename T> struct is_unbounded_array : std::false_type {};
2433

@@ -38,6 +47,9 @@ class work_group_memory_impl {
3847
private:
3948
size_t buffer_size;
4049
friend class sycl::handler;
50+
51+
template <typename DataT, typename Enable>
52+
friend class sycl::ext::oneapi::experimental::dynamic_work_group_memory;
4153
};
4254

4355
} // namespace detail
@@ -115,6 +127,10 @@ class __SYCL_SPECIAL_CLASS __SYCL_TYPE(work_group_memory) work_group_memory
115127
friend class sycl::handler; // needed in order for handler class to be aware
116128
// of the private inheritance with
117129
// work_group_memory_impl as base class
130+
//
131+
template <typename T, typename Enable>
132+
friend class sycl::ext::oneapi::experimental::dynamic_work_group_memory;
133+
118134
decoratedPtr ptr = nullptr;
119135
};
120136
} // namespace ext::oneapi::experimental

sycl/include/sycl/handler.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,7 @@ class pipe;
152152
}
153153

154154
namespace ext ::oneapi ::experimental {
155-
template <typename, typename>
156-
class work_group_memory;
155+
template <typename, typename> class work_group_memory;
157156
struct image_descriptor;
158157
} // namespace ext::oneapi::experimental
159158

@@ -514,7 +513,8 @@ class __SYCL_EXPORT handler {
514513

515514
/// Saves the location of user's code passed in \p CodeLoc for future usage in
516515
/// finalize() method.
517-
/// TODO: remove the first version of this func (the one without the IsTopCodeLoc arg)
516+
/// TODO: remove the first version of this func (the one without the
517+
/// IsTopCodeLoc arg)
518518
/// at the next ABI breaking window since removing it breaks ABI on windows.
519519
void saveCodeLoc(detail::code_location CodeLoc);
520520
void saveCodeLoc(detail::code_location CodeLoc, bool IsTopCodeLoc);
@@ -724,8 +724,9 @@ class __SYCL_EXPORT handler {
724724
detail::KernelLambdaHasKernelHandlerArgT<KernelType,
725725
LambdaArgType>::value;
726726

727-
MHostKernel = std::make_unique<
728-
detail::HostKernel<KernelType, LambdaArgType, Dims>>(KernelFunc);
727+
MHostKernel =
728+
std::make_unique<detail::HostKernel<KernelType, LambdaArgType, Dims>>(
729+
KernelFunc);
729730

730731
constexpr bool KernelHasName =
731732
detail::getKernelName<KernelName>() != nullptr &&
@@ -3769,7 +3770,8 @@ class __SYCL_EXPORT handler {
37693770
"A local accessor must not be used in a SYCL kernel function "
37703771
"that is invoked via single_task or via the simple form of "
37713772
"parallel_for that takes a range parameter.");
3772-
if (Kind == detail::kernel_param_kind_t::kind_work_group_memory)
3773+
if (Kind == detail::kernel_param_kind_t::kind_work_group_memory ||
3774+
Kind == detail::kernel_param_kind_t::kind_dynamic_work_group_memory)
37733775
throw sycl::exception(
37743776
make_error_code(errc::kernel_argument),
37753777
"A work group memory object must not be used in a SYCL kernel "

sycl/source/detail/graph_impl.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1937,6 +1937,11 @@ void executable_command_graph::update(const std::vector<node> &Nodes) {
19371937
impl->update(NodeImpls);
19381938
}
19391939

1940+
dynamic_parameter_base::dynamic_parameter_base(
1941+
command_graph<graph_state::modifiable> Graph)
1942+
: impl(std::make_shared<dynamic_parameter_impl>(
1943+
sycl::detail::getSyclObjImpl(Graph))) {}
1944+
19401945
dynamic_parameter_base::dynamic_parameter_base(
19411946
command_graph<graph_state::modifiable> Graph, size_t ParamSize,
19421947
const void *Data)
@@ -1957,6 +1962,10 @@ void dynamic_parameter_base::updateAccessor(
19571962
impl->updateAccessor(Acc);
19581963
}
19591964

1965+
void dynamic_parameter_base::updateWorkGroupMem(size_t BufferSize) {
1966+
impl->updateWorkGroupMem(BufferSize);
1967+
}
1968+
19601969
void dynamic_parameter_impl::updateValue(const raw_kernel_arg *NewRawValue,
19611970
size_t Size) {
19621971
// Number of bytes is taken from member of raw_kernel_arg object rather
@@ -2012,6 +2021,40 @@ void dynamic_parameter_impl::updateAccessor(
20122021
sizeof(sycl::detail::AccessorBaseHost));
20132022
}
20142023

2024+
void dynamic_parameter_impl::updateWorkGroupMem(size_t BufferSize) {
2025+
for (auto &[NodeWeak, ArgIndex] : MNodes) {
2026+
auto NodeShared = NodeWeak.lock();
2027+
if (NodeShared) {
2028+
dynamic_parameter_impl::updateCGWorkGroupMem(NodeShared->MCommandGroup,
2029+
ArgIndex, BufferSize);
2030+
}
2031+
}
2032+
2033+
for (auto &DynCGInfo : MDynCGs) {
2034+
auto DynCG = DynCGInfo.DynCG.lock();
2035+
if (DynCG) {
2036+
auto &CG = DynCG->MKernels[DynCGInfo.CGIndex];
2037+
dynamic_parameter_impl::updateCGWorkGroupMem(CG, DynCGInfo.ArgIndex,
2038+
BufferSize);
2039+
}
2040+
}
2041+
}
2042+
2043+
void dynamic_parameter_impl::updateCGWorkGroupMem(
2044+
std::shared_ptr<sycl::detail::CG> CG, int ArgIndex, size_t BufferSize) {
2045+
2046+
auto &Args = static_cast<sycl::detail::CGExecKernel *>(CG.get())->MArgs;
2047+
for (auto &Arg : Args) {
2048+
if (Arg.MIndex != ArgIndex) {
2049+
continue;
2050+
}
2051+
assert(Arg.MType ==
2052+
sycl::detail::kernel_param_kind_t::kind_dynamic_work_group_memory);
2053+
Arg.MSize = BufferSize;
2054+
break;
2055+
}
2056+
}
2057+
20152058
void dynamic_parameter_impl::updateCGArgValue(
20162059
std::shared_ptr<sycl::detail::CG> CG, int ArgIndex, const void *NewValue,
20172060
size_t Size) {

sycl/source/detail/graph_impl.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,6 +1430,9 @@ class exec_graph_impl {
14301430

14311431
class dynamic_parameter_impl {
14321432
public:
1433+
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl)
1434+
: MGraph(GraphImpl) {}
1435+
14331436
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl,
14341437
size_t ParamSize, const void *Data)
14351438
: MGraph(GraphImpl), MValueStorage(ParamSize),
@@ -1497,6 +1500,22 @@ class dynamic_parameter_impl {
14971500
/// @param Acc The new accessor value
14981501
void updateAccessor(const sycl::detail::AccessorBaseHost *Acc);
14991502

1503+
/// Update the internal value of this dynamic parameter as well as the value
1504+
/// of this parameter in all registered nodes and dynamic CGs. Should only be
1505+
/// called for dynamic_work_group_memory arguments parameter.
1506+
/// @param BufferSize The total size in bytes of the new work_group_memory
1507+
/// array
1508+
void updateWorkGroupMem(size_t BufferSize);
1509+
1510+
/// Static helper function for updating command-group
1511+
/// dynamic_work_group_memory arguments.
1512+
/// @param CG The command-group to update the argument information for.
1513+
/// @param ArgIndex The argument index to update.
1514+
/// @param BufferSize The total size in bytes of the new work_group_memory
1515+
/// array
1516+
static void updateCGWorkGroupMem(std::shared_ptr<sycl::detail::CG> CG,
1517+
int ArgIndex, size_t BufferSize);
1518+
15001519
/// Static helper function for updating command-group value arguments.
15011520
/// @param CG The command-group to update the argument information for.
15021521
/// @param ArgIndex The argument index to update.

sycl/source/detail/jit_compiler.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ translateArgType(kernel_param_kind_t Kind) {
170170
return PK::Stream;
171171
case kind::kind_work_group_memory:
172172
return PK::WorkGroupMemory;
173+
case kind::kind_dynamic_work_group_memory:
174+
return PK::DynamicWorkGroupMemory;
173175
case kind::kind_invalid:
174176
return PK::Invalid;
175177
}

sycl/source/detail/scheduler/commands.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2352,6 +2352,8 @@ void SetArgBasedOnType(
23522352
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
23532353
const sycl::context &Context, detail::ArgDesc &Arg, size_t NextTrueIndex) {
23542354
switch (Arg.MType) {
2355+
case kernel_param_kind_t::kind_dynamic_work_group_memory:
2356+
break;
23552357
case kernel_param_kind_t::kind_work_group_memory:
23562358
break;
23572359
case kernel_param_kind_t::kind_stream:

sycl/source/handler.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -999,6 +999,16 @@ void handler::processArg(void *Ptr, const detail::kernel_param_kind_t &Kind,
999999
}
10001000
break;
10011001
}
1002+
case kernel_param_kind_t::kind_dynamic_work_group_memory: {
1003+
1004+
auto *DynBase = static_cast<
1005+
ext::oneapi::experimental::detail::dynamic_parameter_base *>(Ptr);
1006+
1007+
registerDynamicParameter(*DynBase, Index + IndexShift);
1008+
1009+
Ptr = static_cast<void *>(++DynBase);
1010+
[[fallthrough]];
1011+
}
10021012
case kernel_param_kind_t::kind_work_group_memory: {
10031013
addArg(kernel_param_kind_t::kind_std_layout, nullptr,
10041014
static_cast<detail::work_group_memory_impl *>(Ptr)->buffer_size,

0 commit comments

Comments
 (0)