Skip to content

Commit 7fb552a

Browse files
Implemented dynamic_work_group_memory with lambdas
1 parent 21b195c commit 7fb552a

File tree

5 files changed

+63
-52
lines changed

5 files changed

+63
-52
lines changed

sycl/include/sycl/ext/oneapi/experimental/graph.hpp

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,6 @@ class command_graph<graph_state::executable>
488488
namespace detail {
489489
class __SYCL_EXPORT dynamic_parameter_base {
490490
public:
491-
dynamic_parameter_base() = default;
492491
dynamic_parameter_base(
493492
sycl::ext::oneapi::experimental::command_graph<graph_state::modifiable>
494493
Graph);
@@ -526,37 +525,55 @@ class __SYCL_EXPORT dynamic_parameter_base {
526525
sycl::detail::getSyclObjImpl(const Obj &SyclObject);
527526
};
528527

529-
} // namespace detail
530-
531-
template <typename T> struct is_unbounded_array : std::false_type {};
532-
533-
template <typename T> struct is_unbounded_array<T[]> : std::true_type {};
528+
class dynamic_work_group_memory_base
529+
#ifndef __SYCL_DEVICE_ONLY__
530+
: public dynamic_parameter_base
531+
#endif
532+
{
533+
public:
534+
dynamic_work_group_memory_base() = default;
535+
dynamic_work_group_memory_base(
536+
experimental::command_graph<graph_state::modifiable> Graph, size_t Size)
537+
:
538+
#ifndef __SYCL_DEVICE_ONLY__
539+
dynamic_parameter_base(Graph),
540+
#endif
541+
BufferSize(Size) {
542+
}
534543

535-
template <typename T>
536-
inline constexpr bool is_unbounded_array_v = is_unbounded_array<T>::value;
544+
private:
545+
#ifdef __SYCL_DEVICE_ONLY__
546+
[[maybe_unused]] char padding[sizeof(dynamic_parameter_base)];
547+
#endif
548+
size_t BufferSize{};
549+
friend class sycl::handler;
550+
};
551+
} // namespace detail
537552

538553
template <typename DataT,
539-
typename = std::enable_if_t<is_unbounded_array_v<DataT>>>
540-
554+
typename = std::enable_if_t<detail::is_unbounded_array_v<DataT>>>
541555
class __SYCL_SPECIAL_CLASS
542556
__SYCL_TYPE(dynamic_work_group_memory) dynamic_work_group_memory
557+
: public detail::dynamic_work_group_memory_base {
558+
private:
559+
work_group_memory<DataT> WorkGroupMem;
560+
561+
using value_type = std::remove_all_extents_t<DataT>;
562+
using decoratedPtr = typename sycl::detail::DecoratedType<
563+
value_type, access::address_space::local_space>::type *;
564+
543565
#ifdef __SYCL_DEVICE_ONLY__
544-
: detail::dynamic_parameter_base
545-
#else
546-
: public detail::dynamic_parameter_base
566+
void __init(decoratedPtr Ptr) { this->WorkGroupMem.__init(Ptr); }
547567
#endif
548-
{
568+
549569
public:
550-
dynamic_work_group_memory() = default;
551570
/// Constructs a new dynamic_work_group_memory object.
552571
/// @param Graph The graph associated with this object.
553572
/// @param Num Number of elements in the unbounded array DataT.
554573
dynamic_work_group_memory(
555-
experimental::command_graph<graph_state::modifiable> Graph, size_t Num) {
556-
auto &WorkGroupMemImpl =
557-
static_cast<detail::work_group_memory_impl &>(WorkGroupMem);
558-
WorkGroupMemImpl.buffer_size = Num * sizeof(std::remove_extent_t<DataT>);
559-
}
574+
experimental::command_graph<graph_state::modifiable> Graph, size_t Num)
575+
: detail::dynamic_work_group_memory_base(
576+
Graph, Num * sizeof(std::remove_extent_t<DataT>)) {}
560577

561578
/// Updates this dynamic_work_group_memory and all registered nodes with a new
562579
/// number of elements.
@@ -567,19 +584,15 @@ __SYCL_TYPE(dynamic_work_group_memory) dynamic_work_group_memory
567584
Num * sizeof(std::remove_extent_t<DataT>));
568585
#endif
569586
}
570-
571-
const work_group_memory<DataT> &get() const { return WorkGroupMem; }
572-
573-
private:
574-
work_group_memory<DataT> WorkGroupMem;
575-
#ifdef __SYCL_DEVICE_ONLY__
576-
// [[maybe_unused]] char padding[sizeof(detail::dynamic_parameter_base)];
577-
using value_type = std::remove_all_extents_t<DataT>;
578-
using decoratedPtr = typename sycl::detail::DecoratedType<
579-
value_type, access::address_space::local_space>::type *;
580-
581-
void __init(decoratedPtr Ptr) { this->WorkGroupMem.__init(Ptr); }
582-
#endif
587+
work_group_memory<DataT> get() const { return WorkGroupMem; }
588+
589+
// Frontend requires special types to have a default constructor in order to
590+
// have a uniform way of initializing an object of special type to then call
591+
// the __init method on it. This is purely an implementation detail and not
592+
// part of the spec.
593+
// TODO: Revisit this once https://github.com/intel/llvm/issues/16061 is
594+
// closed.
595+
dynamic_work_group_memory() = default;
583596
};
584597

585598
template <typename ValueT>

sycl/include/sycl/ext/oneapi/experimental/work_group_memory.hpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,6 @@ namespace sycl {
2020
inline namespace _V1 {
2121
class handler;
2222

23-
namespace ext {
24-
namespace oneapi {
25-
namespace experimental {
26-
template <typename DataT, typename Enable> class dynamic_work_group_memory;
27-
}
28-
} // namespace oneapi
29-
} // namespace ext
30-
3123
namespace detail {
3224
template <typename T> struct is_unbounded_array : std::false_type {};
3325

@@ -47,13 +39,12 @@ class work_group_memory_impl {
4739
private:
4840
size_t buffer_size;
4941
friend class sycl::handler;
50-
51-
template <typename DataT, typename Enable>
52-
friend class sycl::ext::oneapi::experimental::dynamic_work_group_memory;
5342
};
5443

5544
} // namespace detail
5645
namespace ext::oneapi::experimental {
46+
// Forward decleration to be able to befriend dynamic_work_group_memory.
47+
// template <typename, typename> class dynamic_work_group_memory;
5748

5849
struct indeterminate_t {};
5950
inline constexpr indeterminate_t indeterminate;
@@ -127,9 +118,8 @@ class __SYCL_SPECIAL_CLASS __SYCL_TYPE(work_group_memory) work_group_memory
127118
friend class sycl::handler; // needed in order for handler class to be aware
128119
// of the private inheritance with
129120
// work_group_memory_impl as base class
130-
//
131-
template <typename T, typename Enable>
132-
friend class sycl::ext::oneapi::experimental::dynamic_work_group_memory;
121+
122+
template <typename, typename> friend class dynamic_work_group_memory;
133123

134124
decoratedPtr ptr = nullptr;
135125
};

sycl/include/sycl/handler.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,14 @@ class pipe;
153153

154154
namespace ext ::oneapi ::experimental {
155155
template <typename, typename> class work_group_memory;
156+
template <typename, typename> class dynamic_work_group_memory;
156157
struct image_descriptor;
157158
} // namespace ext::oneapi::experimental
158159

159160
namespace ext::oneapi::experimental::detail {
160161
class graph_impl;
162+
class dynamic_work_group_memory_base;
163+
class dynamic_parameter_base;
161164
} // namespace ext::oneapi::experimental::detail
162165
namespace detail {
163166

sycl/source/detail/graph_impl.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2033,7 +2033,7 @@ void dynamic_parameter_impl::updateWorkGroupMem(size_t BufferSize) {
20332033
for (auto &DynCGInfo : MDynCGs) {
20342034
auto DynCG = DynCGInfo.DynCG.lock();
20352035
if (DynCG) {
2036-
auto &CG = DynCG->MKernels[DynCGInfo.CGIndex];
2036+
auto &CG = DynCG->MCommandGroups[DynCGInfo.CGIndex];
20372037
dynamic_parameter_impl::updateCGWorkGroupMem(CG, DynCGInfo.ArgIndex,
20382038
BufferSize);
20392039
}
@@ -2048,8 +2048,7 @@ void dynamic_parameter_impl::updateCGWorkGroupMem(
20482048
if (Arg.MIndex != ArgIndex) {
20492049
continue;
20502050
}
2051-
assert(Arg.MType ==
2052-
sycl::detail::kernel_param_kind_t::kind_dynamic_work_group_memory);
2051+
assert(Arg.MType == sycl::detail::kernel_param_kind_t::kind_std_layout);
20532052
Arg.MSize = BufferSize;
20542053
break;
20552054
}

sycl/source/handler.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "sycl/detail/helpers.hpp"
10+
#include "sycl/ext/oneapi/experimental/graph.hpp"
1011
#include "ur_api.h"
1112
#include <algorithm>
1213

@@ -1004,10 +1005,15 @@ void handler::processArg(void *Ptr, const detail::kernel_param_kind_t &Kind,
10041005
auto *DynBase = static_cast<
10051006
ext::oneapi::experimental::detail::dynamic_parameter_base *>(Ptr);
10061007

1008+
auto *DynWorkGroupBase = static_cast<
1009+
ext::oneapi::experimental::detail::dynamic_work_group_memory_base *>(
1010+
Ptr);
1011+
10071012
registerDynamicParameter(*DynBase, Index + IndexShift);
10081013

1009-
Ptr = static_cast<void *>(++DynBase);
1010-
[[fallthrough]];
1014+
addArg(kernel_param_kind_t::kind_std_layout, nullptr,
1015+
DynWorkGroupBase->BufferSize, Index + IndexShift);
1016+
break;
10111017
}
10121018
case kernel_param_kind_t::kind_work_group_memory: {
10131019
addArg(kernel_param_kind_t::kind_std_layout, nullptr,

0 commit comments

Comments
 (0)