Skip to content

Commit a21a741

Browse files
Implemented dynamic_work_group_memory with lambdas
1 parent 999c682 commit a21a741

File tree

15 files changed

+219
-6
lines changed

15 files changed

+219
-6
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,12 +1606,12 @@ def SYCLType: InheritableAttr {
16061606
let Subjects = SubjectList<[CXXRecord, Enum], ErrorDiag>;
16071607
let LangOpts = [SYCLIsDevice, SilentlyIgnoreSYCLIsHost];
16081608
let Args = [EnumArgument<"Type", "SYCLType", /*is_string=*/true,
1609-
["accessor", "local_accessor", "work_group_memory",
1609+
["accessor", "local_accessor", "work_group_memory", "dynamic_work_group_memory",
16101610
"specialization_id", "kernel_handler", "buffer_location",
16111611
"no_alias", "accessor_property_list", "group",
16121612
"private_memory", "aspect", "annotated_ptr", "annotated_arg",
16131613
"stream", "sampler", "host_pipe", "multi_ptr"],
1614-
["accessor", "local_accessor", "work_group_memory",
1614+
["accessor", "local_accessor", "work_group_memory", "dynamic_work_group_memory",
16151615
"specialization_id", "kernel_handler", "buffer_location",
16161616
"no_alias", "accessor_property_list", "group",
16171617
"private_memory", "aspect", "annotated_ptr", "annotated_arg",

clang/include/clang/Sema/SemaSYCL.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ class SYCLIntegrationHeader {
6363
kind_specialization_constants_buffer,
6464
kind_stream,
6565
kind_work_group_memory,
66-
kind_last = kind_work_group_memory
66+
kind_dynamic_work_group_memory,
67+
kind_last = kind_dynamic_work_group_memory
6768
};
6869

6970
public:
@@ -666,7 +667,7 @@ class SemaSYCL : public SemaBase {
666667
// Used to check whether the function represented by FD is a SYCL
667668
// free function kernel or not.
668669
bool isFreeFunction(const FunctionDecl *FD);
669-
670+
670671
StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, CompoundStmt *Body);
671672
};
672673

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4846,6 +4846,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
48464846
} else if (SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::work_group_memory)) {
48474847
addParam(FieldTy, SYCLIntegrationHeader::kind_work_group_memory,
48484848
offsetOf(FD, FieldTy));
4849+
} else if (SemaSYCL::isSyclType(FieldTy,
4850+
SYCLTypeAttr::dynamic_work_group_memory)) {
4851+
addParam(FieldTy, SYCLIntegrationHeader::kind_dynamic_work_group_memory,
4852+
offsetOf(FD, FieldTy));
48494853
} else if (SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::sampler) ||
48504854
SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::annotated_ptr) ||
48514855
SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::annotated_arg)) {
@@ -5993,6 +5997,7 @@ static const char *paramKind2Str(KernelParamKind K) {
59935997
CASE(specialization_constants_buffer);
59945998
CASE(pointer);
59955999
CASE(work_group_memory);
6000+
CASE(dynamic_work_group_memory);
59966001
}
59976002
return "<ERROR>";
59986003

clang/test/SemaSYCL/Inputs/sycl/detail/kernel_desc.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ namespace detail {
1919
kind_specialization_constants_buffer = 4,
2020
kind_stream = 5,
2121
kind_work_group_memory = 6,
22+
kind_dynamic_work_group_memory = 7,
2223
kind_invalid = 0xf, // not a valid kernel kind
2324
};
2425

sycl-jit/common/include/Kernel.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ enum class ParameterKind : uint32_t {
6060
SpecConstBuffer = 4,
6161
Stream = 5,
6262
WorkGroupMemory = 6,
63+
DynamicWorkGroupMemory = 7,
6364
Invalid = 0xF,
6465
};
6566

sycl/include/sycl/detail/kernel_desc.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ enum class kernel_param_kind_t {
5959
kind_specialization_constants_buffer = 4,
6060
kind_stream = 5,
6161
kind_work_group_memory = 6,
62+
kind_dynamic_work_group_memory = 7,
6263
kind_invalid = 0xf, // not a valid kernel kind
6364
};
6465

sycl/include/sycl/ext/oneapi/experimental/graph.hpp

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010

11+
#include <cstddef>
1112
#include <sycl/accessor.hpp> // for detail::AccessorBaseHost
1213
#include <sycl/context.hpp> // for context
1314
#include <sycl/detail/export.hpp> // for __SYCL_EXPORT
@@ -17,8 +18,9 @@
1718
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
1819
#include <sycl/detail/string_view.hpp>
1920
#endif
20-
#include <sycl/device.hpp> // for device
21+
#include <sycl/device.hpp> // for device
2122
#include <sycl/ext/oneapi/experimental/detail/properties/graph_properties.hpp> // for graph properties classes
23+
#include <sycl/ext/oneapi/experimental/work_group_memory.hpp> // for dynamic_work_group_memory
2224
#include <sycl/nd_range.hpp> // for range, nd_range
2325
#include <sycl/properties/property_traits.hpp> // for is_property, is_property_of
2426
#include <sycl/property_list.hpp> // for property_list
@@ -501,6 +503,9 @@ class command_graph<graph_state::executable>
501503
namespace detail {
502504
class __SYCL_EXPORT dynamic_parameter_base {
503505
public:
506+
dynamic_parameter_base(
507+
sycl::ext::oneapi::experimental::command_graph<graph_state::modifiable>
508+
Graph);
504509
dynamic_parameter_base(
505510
sycl::ext::oneapi::experimental::command_graph<graph_state::modifiable>
506511
Graph,
@@ -525,14 +530,86 @@ class __SYCL_EXPORT dynamic_parameter_base {
525530
void updateValue(const raw_kernel_arg *NewRawValue, size_t Size);
526531

527532
void updateAccessor(const sycl::detail::AccessorBaseHost *Acc);
533+
534+
void updateWorkGroupMem(size_t BufferSize);
535+
528536
std::shared_ptr<dynamic_parameter_impl> impl;
529537

530538
template <class Obj>
531539
friend const decltype(Obj::impl) &
532540
sycl::detail::getSyclObjImpl(const Obj &SyclObject);
533541
};
542+
543+
/// Non-template base of dynamic_work_group_memory that stores the size passed
/// at construction.
///
/// In host compilation (__SYCL_DEVICE_ONLY__ not defined) the class derives
/// from dynamic_parameter_base. In device compilation the base class is
/// omitted and a byte array of sizeof(dynamic_parameter_base) is declared in
/// its place.
class dynamic_work_group_memory_base
544+
#ifndef __SYCL_DEVICE_ONLY__
545+
: public dynamic_parameter_base
546+
#endif
547+
{
548+
public:
549+
dynamic_work_group_memory_base() = default;
550+
/// @param Graph Graph forwarded to dynamic_parameter_base (host compilation
/// only; unused in device compilation).
/// @param Size Value stored in BufferSize. dynamic_work_group_memory passes a
/// byte count here.
dynamic_work_group_memory_base(
551+
experimental::command_graph<graph_state::modifiable> Graph, size_t Size)
552+
:
553+
#ifndef __SYCL_DEVICE_ONLY__
554+
dynamic_parameter_base(Graph),
555+
#endif
556+
BufferSize(Size) {
557+
}
558+
559+
private:
560+
#ifdef __SYCL_DEVICE_ONLY__
// Occupies as many bytes as the host-only base class would.
// NOTE(review): presumably keeps the object size identical between host and
// device compilation - confirm.
561+
[[maybe_unused]] unsigned char Padding[sizeof(dynamic_parameter_base)];
562+
#endif
563+
// Size given to the constructor; zero-initialized by the default constructor.
size_t BufferSize{};
564+
// Gives sycl::handler access to the private members of this class.
friend class sycl::handler;
565+
};
534566
} // namespace detail
535567

568+
/// Work-group memory kernel argument whose element count can be changed
/// through update(). Wraps a work_group_memory<DataT> object and derives from
/// detail::dynamic_work_group_memory_base.
///
/// @tparam DataT Must satisfy detail::is_unbounded_array_v; the class is
/// disabled through std::enable_if_t otherwise.
template <typename DataT,
569+
typename = std::enable_if_t<detail::is_unbounded_array_v<DataT>>>
570+
class __SYCL_SPECIAL_CLASS
571+
__SYCL_TYPE(dynamic_work_group_memory) dynamic_work_group_memory
572+
: public detail::dynamic_work_group_memory_base {
573+
private:
574+
// Wrapped work_group_memory object; a copy of it is returned by get().
work_group_memory<DataT> WorkGroupMem;
575+
576+
#ifdef __SYCL_DEVICE_ONLY__
577+
using value_type = std::remove_all_extents_t<DataT>;
578+
using decoratedPtr = typename sycl::detail::DecoratedType<
579+
value_type, access::address_space::local_space>::type *;
580+
581+
// Device-only initializer: forwards the local-address-space pointer to the
// wrapped work_group_memory object's own __init.
void __init(decoratedPtr Ptr) { this->WorkGroupMem.__init(Ptr); }
582+
#endif
583+
584+
public:
585+
/// Constructs a new dynamic_work_group_memory object.
586+
/// @param Graph The graph associated with this object.
587+
/// @param Num Number of elements in the unbounded array DataT.
588+
// The base class receives a byte count: Num multiplied by the size of
// std::remove_extent_t<DataT>.
dynamic_work_group_memory(
589+
experimental::command_graph<graph_state::modifiable> Graph, size_t Num)
590+
: detail::dynamic_work_group_memory_base(
591+
Graph, Num * sizeof(std::remove_extent_t<DataT>)) {}
592+
593+
/// Updates on the host this dynamic_work_group_memory and all registered
594+
/// nodes with a new buffer size.
595+
/// @param Num The new number of elements in the unbounded array.
596+
// In device compilation the body is preprocessed away and the call does
// nothing.
// NOTE(review): the size stored in the base class at construction is not
// modified here - confirm that is intended.
void update(size_t Num) {
597+
#ifndef __SYCL_DEVICE_ONLY__
598+
detail::dynamic_parameter_base::updateWorkGroupMem(
599+
Num * sizeof(std::remove_extent_t<DataT>));
600+
#endif
601+
}
602+
/// Returns a copy of the wrapped work_group_memory object.
work_group_memory<DataT> get() const { return WorkGroupMem; }
603+
604+
// Frontend requires special types to have a default constructor in order to
605+
// have a uniform way of initializing an object of special type to then call
606+
// the __init method on it. This is purely an implementation detail and not
607+
// part of the spec.
608+
// TODO: Revisit this once https://github.com/intel/llvm/issues/16061 is
609+
// closed.
610+
dynamic_work_group_memory() = default;
611+
};
612+
536613
template <typename ValueT>
537614
class dynamic_parameter : public detail::dynamic_parameter_base {
538615
static constexpr bool IsAccessor =

sycl/include/sycl/ext/oneapi/experimental/work_group_memory.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#pragma once
99

10+
#include "sycl/ext/oneapi/experimental/graph.hpp"
1011
#include <sycl/access/access.hpp>
1112
#include <sycl/detail/defines.hpp>
1213
#include <sycl/ext/oneapi/properties/properties.hpp>
@@ -115,6 +116,9 @@ class __SYCL_SPECIAL_CLASS __SYCL_TYPE(work_group_memory) work_group_memory
115116
friend class sycl::handler; // needed in order for handler class to be aware
116117
// of the private inheritance with
117118
// work_group_memory_impl as base class
119+
120+
template <typename, typename> friend class dynamic_work_group_memory;
121+
118122
decoratedPtr ptr = nullptr;
119123
};
120124
} // namespace ext::oneapi::experimental

sycl/include/sycl/handler.hpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ class pipe;
150150

151151
namespace ext ::oneapi ::experimental {
152152
template <typename, typename> class work_group_memory;
153+
template <typename, typename> class dynamic_work_group_memory;
153154
struct image_descriptor;
154155
__SYCL_EXPORT void async_free(sycl::handler &h, void *ptr);
155156
__SYCL_EXPORT void *async_malloc(sycl::handler &h, sycl::usm::alloc kind,
@@ -160,6 +161,8 @@ __SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
160161

161162
namespace ext::oneapi::experimental::detail {
162163
class graph_impl;
164+
class dynamic_work_group_memory_base;
165+
class dynamic_parameter_base;
163166
} // namespace ext::oneapi::experimental::detail
164167
namespace detail {
165168

@@ -680,6 +683,12 @@ class __SYCL_EXPORT handler {
680683
registerDynamicParameter(DynamicParam, ArgIndex);
681684
}
682685

686+
// setArgHelper for graph dynamic_work_group_memory
687+
void
688+
setArgHelper(int ArgIndex,
689+
ext::oneapi::experimental::detail::dynamic_work_group_memory_base
690+
&DynWorkGroupMemParam);
691+
683692
// setArgHelper for the raw_kernel_arg extension type.
684693
void setArgHelper(int ArgIndex,
685694
sycl::ext::oneapi::experimental::raw_kernel_arg &&Arg) {
@@ -1879,6 +1888,16 @@ class __SYCL_EXPORT handler {
18791888
setArgHelper(argIndex, dynamicParam);
18801889
}
18811890

1891+
// set_arg for graph dynamic_work_group_memory
1892+
template <typename DataT>
1893+
void set_arg(int argIndex,
1894+
ext::oneapi::experimental::dynamic_work_group_memory<DataT>
1895+
&dynWorkGroupMem) {
1896+
ext::oneapi::experimental::detail::dynamic_work_group_memory_base
1897+
&dynWorkGroupMemImpl = dynWorkGroupMem;
1898+
setArgHelper(argIndex, dynWorkGroupMemImpl);
1899+
}
1900+
18821901
// set_arg for the raw_kernel_arg extension type.
18831902
void set_arg(int argIndex, ext::oneapi::experimental::raw_kernel_arg &&Arg) {
18841903
setArgHelper(argIndex, std::move(Arg));
@@ -3771,7 +3790,8 @@ class __SYCL_EXPORT handler {
37713790
"A local accessor must not be used in a SYCL kernel function "
37723791
"that is invoked via single_task or via the simple form of "
37733792
"parallel_for that takes a range parameter.");
3774-
if (Kind == detail::kernel_param_kind_t::kind_work_group_memory)
3793+
if (Kind == detail::kernel_param_kind_t::kind_work_group_memory ||
3794+
Kind == detail::kernel_param_kind_t::kind_dynamic_work_group_memory)
37753795
throw sycl::exception(
37763796
make_error_code(errc::kernel_argument),
37773797
"A work group memory object must not be used in a SYCL kernel "

sycl/source/detail/graph_impl.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1994,6 +1994,11 @@ void executable_command_graph::update(const std::vector<node> &Nodes) {
19941994
impl->update(NodeImpls);
19951995
}
19961996

1997+
/// Constructs a dynamic parameter that is only associated with \p Graph. The
/// implementation object is created from the graph's implementation object
/// alone; no parameter size or initial data is passed.
dynamic_parameter_base::dynamic_parameter_base(
1998+
command_graph<graph_state::modifiable> Graph)
1999+
: impl(std::make_shared<dynamic_parameter_impl>(
2000+
sycl::detail::getSyclObjImpl(Graph))) {}
2001+
19972002
dynamic_parameter_base::dynamic_parameter_base(
19982003
command_graph<graph_state::modifiable> Graph, size_t ParamSize,
19992004
const void *Data)
@@ -2014,6 +2019,10 @@ void dynamic_parameter_base::updateAccessor(
20142019
impl->updateAccessor(Acc);
20152020
}
20162021

2022+
/// Forwards a new work-group memory size to the implementation object.
/// @param BufferSize Size passed unchanged to
/// dynamic_parameter_impl::updateWorkGroupMem.
void dynamic_parameter_base::updateWorkGroupMem(size_t BufferSize) {
2023+
impl->updateWorkGroupMem(BufferSize);
2024+
}
2025+
20172026
void dynamic_parameter_impl::updateValue(const raw_kernel_arg *NewRawValue,
20182027
size_t Size) {
20192028
// Number of bytes is taken from member of raw_kernel_arg object rather
@@ -2069,6 +2078,39 @@ void dynamic_parameter_impl::updateAccessor(
20692078
sizeof(sycl::detail::AccessorBaseHost));
20702079
}
20712080

2081+
void dynamic_parameter_impl::updateWorkGroupMem(size_t BufferSize) {
2082+
for (auto &[NodeWeak, ArgIndex] : MNodes) {
2083+
auto NodeShared = NodeWeak.lock();
2084+
if (NodeShared) {
2085+
dynamic_parameter_impl::updateCGWorkGroupMem(NodeShared->MCommandGroup,
2086+
ArgIndex, BufferSize);
2087+
}
2088+
}
2089+
2090+
for (auto &DynCGInfo : MDynCGs) {
2091+
auto DynCG = DynCGInfo.DynCG.lock();
2092+
if (DynCG) {
2093+
auto &CG = DynCG->MCommandGroups[DynCGInfo.CGIndex];
2094+
dynamic_parameter_impl::updateCGWorkGroupMem(CG, DynCGInfo.ArgIndex,
2095+
BufferSize);
2096+
}
2097+
}
2098+
}
2099+
2100+
void dynamic_parameter_impl::updateCGWorkGroupMem(
2101+
std::shared_ptr<sycl::detail::CG> CG, int ArgIndex, size_t BufferSize) {
2102+
2103+
auto &Args = static_cast<sycl::detail::CGExecKernel *>(CG.get())->MArgs;
2104+
for (auto &Arg : Args) {
2105+
if (Arg.MIndex != ArgIndex) {
2106+
continue;
2107+
}
2108+
assert(Arg.MType == sycl::detail::kernel_param_kind_t::kind_std_layout);
2109+
Arg.MSize = BufferSize;
2110+
break;
2111+
}
2112+
}
2113+
20722114
void dynamic_parameter_impl::updateCGArgValue(
20732115
std::shared_ptr<sycl::detail::CG> CG, int ArgIndex, const void *NewValue,
20742116
size_t Size) {

sycl/source/detail/graph_impl.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,6 +1479,9 @@ class exec_graph_impl {
14791479

14801480
class dynamic_parameter_impl {
14811481
public:
1482+
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl)
1483+
: MGraph(GraphImpl) {}
1484+
14821485
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl,
14831486
size_t ParamSize, const void *Data)
14841487
: MGraph(GraphImpl), MValueStorage(ParamSize),
@@ -1546,6 +1549,22 @@ class dynamic_parameter_impl {
15461549
/// @param Acc The new accessor value
15471550
void updateAccessor(const sycl::detail::AccessorBaseHost *Acc);
15481551

1552+
/// Update the buffer size of this parameter in all registered nodes and
1553+
/// dynamic command-groups. Should only be called for
1554+
/// dynamic_work_group_memory parameters.
1555+
/// @param BufferSize The total size in bytes of the new work_group_memory
1556+
/// array
1557+
void updateWorkGroupMem(size_t BufferSize);
1558+
1559+
/// Static helper function for updating command-group
1560+
/// dynamic_work_group_memory arguments.
1561+
/// @param CG The command-group to update the argument information for.
1562+
/// @param ArgIndex The argument index to update.
1563+
/// @param BufferSize The total size in bytes of the new work_group_memory
1564+
/// array
1565+
static void updateCGWorkGroupMem(std::shared_ptr<sycl::detail::CG> CG,
1566+
int ArgIndex, size_t BufferSize);
1567+
15491568
/// Static helper function for updating command-group value arguments.
15501569
/// @param CG The command-group to update the argument information for.
15511570
/// @param ArgIndex The argument index to update.

sycl/source/detail/handler_impl.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,11 @@ class handler_impl {
199199
std::vector<std::shared_ptr<detail::work_group_memory_impl>>
200200
MWorkGroupMemoryObjects;
201201

202+
/// List of dynamic work group memory objects associated with this handler
203+
std::vector<std::shared_ptr<
204+
ext::oneapi::experimental::detail::dynamic_work_group_memory_base>>
205+
MDynWorkGroupMemoryParams;
206+
202207
/// Potential event mode for the result event of the command.
203208
ext::oneapi::experimental::event_mode_enum MEventMode =
204209
ext::oneapi::experimental::event_mode_enum::none;

sycl/source/detail/jit_compiler.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@ translateArgType(kernel_param_kind_t Kind) {
180180
return PK::Stream;
181181
case kind::kind_work_group_memory:
182182
return PK::WorkGroupMemory;
183+
case kind::kind_dynamic_work_group_memory:
184+
return PK::DynamicWorkGroupMemory;
183185
case kind::kind_invalid:
184186
return PK::Invalid;
185187
}

sycl/source/detail/scheduler/commands.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2319,6 +2319,8 @@ void SetArgBasedOnType(
23192319
const ContextImplPtr &ContextImpl, detail::ArgDesc &Arg,
23202320
size_t NextTrueIndex) {
23212321
switch (Arg.MType) {
2322+
case kernel_param_kind_t::kind_dynamic_work_group_memory:
2323+
break;
23222324
case kernel_param_kind_t::kind_work_group_memory:
23232325
break;
23242326
case kernel_param_kind_t::kind_stream:

0 commit comments

Comments
 (0)