Skip to content

Commit 70fee11

Browse files
Added support for free function kernels
1 parent a21a741 commit 70fee11

File tree

18 files changed

+562
-58
lines changed

18 files changed

+562
-58
lines changed

clang/include/clang/Sema/SemaSYCL.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,7 @@ class SemaSYCL : public SemaBase {
667667
// Used to check whether the function represented by FD is a SYCL
668668
// free function kernel or not.
669669
bool isFreeFunction(const FunctionDecl *FD);
670-
670+
671671
StmtResult BuildSYCLKernelCallStmt(FunctionDecl *FD, CompoundStmt *Body);
672672
};
673673

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2090,7 +2090,9 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
20902090
}
20912091

20922092
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
2093-
if (!SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) {
2093+
if (!SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory) &&
2094+
!SemaSYCL::isSyclType(ParamTy,
2095+
SYCLTypeAttr::dynamic_work_group_memory)) {
20942096
Diag.Report(PD->getLocation(), diag::err_bad_kernel_param_type)
20952097
<< ParamTy;
20962098
IsInvalid = true;
@@ -2246,7 +2248,8 @@ class SyclKernelUnionChecker : public SyclKernelFieldHandler {
22462248
}
22472249

22482250
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
2249-
if (!SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory))
2251+
if (!SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory) &&
2252+
!SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::dynamic_work_group_memory))
22502253
unsupportedFreeFunctionParamType(); // TODO
22512254
return true;
22522255
}
@@ -3032,7 +3035,9 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
30323035
}
30333036

30343037
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
3035-
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) {
3038+
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory) ||
3039+
SemaSYCL::isSyclType(ParamTy,
3040+
SYCLTypeAttr::dynamic_work_group_memory)) {
30363041
const auto *RecordDecl = ParamTy->getAsCXXRecordDecl();
30373042
assert(RecordDecl && "The type must be a RecordDecl");
30383043
CXXMethodDecl *InitMethod = getMethodByName(RecordDecl, InitMethodName);
@@ -4544,7 +4549,9 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
45444549
// TODO: Revisit this approach once https://github.com/intel/llvm/issues/16061
45454550
// is closed.
45464551
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
4547-
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory)) {
4552+
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory) ||
4553+
SemaSYCL::isSyclType(ParamTy,
4554+
SYCLTypeAttr::dynamic_work_group_memory)) {
45484555
const auto *RecordDecl = ParamTy->getAsCXXRecordDecl();
45494556
AccessSpecifier DefaultConstructorAccess;
45504557
auto DefaultConstructor =
@@ -4823,6 +4830,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
48234830
} else if (SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::work_group_memory)) {
48244831
addParam(FieldTy, SYCLIntegrationHeader::kind_work_group_memory,
48254832
offsetOf(RD, BC.getType()->getAsCXXRecordDecl()));
4833+
} else if (SemaSYCL::isSyclType(FieldTy,
4834+
SYCLTypeAttr::dynamic_work_group_memory)) {
4835+
addParam(FieldTy, SYCLIntegrationHeader::kind_dynamic_work_group_memory,
4836+
offsetOf(RD, BC.getType()->getAsCXXRecordDecl()));
48264837
}
48274838
return true;
48284839
}
@@ -4847,9 +4858,9 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
48474858
addParam(FieldTy, SYCLIntegrationHeader::kind_work_group_memory,
48484859
offsetOf(FD, FieldTy));
48494860
} else if (SemaSYCL::isSyclType(FieldTy,
4850-
SYCLTypeAttr::dynamic_work_group_memory)) {
4861+
SYCLTypeAttr::dynamic_work_group_memory)) {
48514862
addParam(FieldTy, SYCLIntegrationHeader::kind_dynamic_work_group_memory,
4852-
offsetOf(FD, FieldTy));
4863+
offsetOf(FD, FieldTy));
48534864
} else if (SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::sampler) ||
48544865
SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::annotated_ptr) ||
48554866
SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::annotated_arg)) {
@@ -4874,6 +4885,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
48744885
bool handleSyclSpecialType(ParmVarDecl *PD, QualType ParamTy) final {
48754886
if (SemaSYCL::isSyclType(ParamTy, SYCLTypeAttr::work_group_memory))
48764887
addParam(PD, ParamTy, SYCLIntegrationHeader::kind_work_group_memory);
4888+
else if (SemaSYCL::isSyclType(ParamTy,
4889+
SYCLTypeAttr::dynamic_work_group_memory))
4890+
addParam(PD, ParamTy,
4891+
SYCLIntegrationHeader::kind_dynamic_work_group_memory);
48774892
else
48784893
unsupportedFreeFunctionParamType(); // TODO
48794894
return true;

sycl/include/sycl/ext/oneapi/experimental/graph.hpp

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#pragma once
1010

11-
#include <cstddef>
1211
#include <sycl/accessor.hpp> // for detail::AccessorBaseHost
1312
#include <sycl/context.hpp> // for context
1413
#include <sycl/detail/export.hpp> // for __SYCL_EXPORT
@@ -21,7 +20,8 @@
2120
#include <sycl/device.hpp> // for device
2221
#include <sycl/ext/oneapi/experimental/detail/properties/graph_properties.hpp> // for graph properties classes
2322
#include <sycl/ext/oneapi/experimental/work_group_memory.hpp> // for dynamic_work_group_memory
24-
#include <sycl/nd_range.hpp> // for range, nd_range
23+
#include <sycl/ext/oneapi/properties/properties.hpp> // for empty_properties_t
24+
#include <sycl/nd_range.hpp> // for range, nd_range
2525
#include <sycl/properties/property_traits.hpp> // for is_property, is_property_of
2626
#include <sycl/property_list.hpp> // for property_list
2727

@@ -49,6 +49,7 @@ enum class graph_state {
4949
// Forward declare ext::oneapi::experimental classes
5050
template <graph_state State> class command_graph;
5151
class raw_kernel_arg;
52+
template <typename, typename> class work_group_memory;
5253

5354
namespace detail {
5455
// List of sycl features and extensions which are not supported by graphs. Used
@@ -503,6 +504,7 @@ class command_graph<graph_state::executable>
503504
namespace detail {
504505
class __SYCL_EXPORT dynamic_parameter_base {
505506
public:
507+
dynamic_parameter_base() = default;
506508
dynamic_parameter_base(
507509
sycl::ext::oneapi::experimental::command_graph<graph_state::modifiable>
508510
Graph);
@@ -548,12 +550,13 @@ class dynamic_work_group_memory_base
548550
public:
549551
dynamic_work_group_memory_base() = default;
550552
dynamic_work_group_memory_base(
551-
experimental::command_graph<graph_state::modifiable> Graph, size_t Size)
552-
:
553+
[[maybe_unused]] experimental::command_graph<graph_state::modifiable>
554+
Graph,
555+
[[maybe_unused]] size_t Size)
553556
#ifndef __SYCL_DEVICE_ONLY__
554-
dynamic_parameter_base(Graph),
557+
: dynamic_parameter_base(Graph), BufferSize(Size)
555558
#endif
556-
BufferSize(Size) {
559+
{
557560
}
558561

559562
private:
@@ -565,23 +568,23 @@ class dynamic_work_group_memory_base
565568
};
566569
} // namespace detail
567570

568-
template <typename DataT,
569-
typename = std::enable_if_t<detail::is_unbounded_array_v<DataT>>>
571+
template <typename DataT, typename PropertyListT = empty_properties_t>
570572
class __SYCL_SPECIAL_CLASS
571573
__SYCL_TYPE(dynamic_work_group_memory) dynamic_work_group_memory
572574
: public detail::dynamic_work_group_memory_base {
573-
private:
574-
work_group_memory<DataT> WorkGroupMem;
575-
576-
#ifdef __SYCL_DEVICE_ONLY__
577-
using value_type = std::remove_all_extents_t<DataT>;
578-
using decoratedPtr = typename sycl::detail::DecoratedType<
579-
value_type, access::address_space::local_space>::type *;
575+
public:
576+
// Check that DataT is an unbounded array type.
577+
static_assert(std::is_array_v<DataT> && std::extent_v<DataT, 0> == 0);
578+
static_assert(std::is_same_v<PropertyListT, empty_properties_t>);
580579

581-
void __init(decoratedPtr Ptr) { this->WorkGroupMem.__init(Ptr); }
582-
#endif
580+
// Frontend requires special types to have a default constructor in order to
581+
// have a uniform way of initializing an object of special type to then call
582+
// the __init method on it. This is purely an implementation detail and not
583+
// part of the spec.
584+
// TODO: Revisit this once https://github.com/intel/llvm/issues/16061 is
585+
// closed.
586+
dynamic_work_group_memory() = default;
583587

584-
public:
585588
/// Constructs a new dynamic_work_group_memory object.
586589
/// @param Graph The graph associated with this object.
587590
/// @param Num Number of elements in the unbounded array DataT.
@@ -590,24 +593,35 @@ __SYCL_TYPE(dynamic_work_group_memory) dynamic_work_group_memory
590593
: detail::dynamic_work_group_memory_base(
591594
Graph, Num * sizeof(std::remove_extent_t<DataT>)) {}
592595

596+
work_group_memory<DataT, PropertyListT> get() const {
597+
#ifndef __SYCL_DEVICE_ONLY__
598+
throw sycl::exception(sycl::make_error_code(errc::invalid),
599+
"Error: dynamic_work_group_memory::get() can be only "
600+
"called on the device!");
601+
#endif
602+
return WorkGroupMem;
603+
}
604+
593605
/// Updates on the host this dynamic_work_group_memory and all registered
594606
/// nodes with a new buffer size.
595607
/// @param Num The new number of elements in the unbounded array.
596-
void update(size_t Num) {
608+
void update([[maybe_unused]] size_t Num) {
597609
#ifndef __SYCL_DEVICE_ONLY__
598610
detail::dynamic_parameter_base::updateWorkGroupMem(
599611
Num * sizeof(std::remove_extent_t<DataT>));
600612
#endif
601613
}
602-
work_group_memory<DataT> get() const { return WorkGroupMem; }
603614

604-
// Frontend requires special types to have a default constructor in order to
605-
// have a uniform way of initializing an object of special type to then call
606-
// the __init method on it. This is purely an implementation detail and not
607-
// part of the spec.
608-
// TODO: Revisit this once https://github.com/intel/llvm/issues/16061 is
609-
// closed.
610-
dynamic_work_group_memory() = default;
615+
private:
616+
work_group_memory<DataT, PropertyListT> WorkGroupMem;
617+
618+
#ifdef __SYCL_DEVICE_ONLY__
619+
using value_type = std::remove_all_extents_t<DataT>;
620+
using decoratedPtr = typename sycl::detail::DecoratedType<
621+
value_type, access::address_space::local_space>::type *;
622+
623+
void __init(decoratedPtr Ptr) { this->WorkGroupMem.__init(Ptr); }
624+
#endif
611625
};
612626

613627
template <typename ValueT>
@@ -684,4 +698,14 @@ struct hash<sycl::ext::oneapi::experimental::dynamic_parameter<ValueT>> {
684698
return std::hash<decltype(ID)>()(ID);
685699
}
686700
};
701+
702+
template <typename DataT>
703+
struct hash<sycl::ext::oneapi::experimental::dynamic_work_group_memory<DataT>> {
704+
size_t operator()(
705+
const sycl::ext::oneapi::experimental::dynamic_work_group_memory<DataT>
706+
&DynWorkGroupMem) const {
707+
auto ID = sycl::detail::getSyclObjImpl(DynWorkGroupMem)->getID();
708+
return std::hash<decltype(ID)>()(ID);
709+
}
710+
};
687711
} // namespace std

sycl/include/sycl/ext/oneapi/experimental/work_group_memory.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77

88
#pragma once
99

10-
#include "sycl/ext/oneapi/experimental/graph.hpp"
1110
#include <sycl/access/access.hpp>
1211
#include <sycl/detail/defines.hpp>
12+
#include <sycl/ext/oneapi/experimental/graph.hpp>
1313
#include <sycl/ext/oneapi/properties/properties.hpp>
1414
#include <sycl/multi_ptr.hpp>
1515

sycl/include/sycl/handler.hpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ __SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
161161

162162
namespace ext::oneapi::experimental::detail {
163163
class graph_impl;
164-
class dynamic_work_group_memory_base;
165164
class dynamic_parameter_base;
165+
class dynamic_work_group_memory_base;
166166
} // namespace ext::oneapi::experimental::detail
167167
namespace detail {
168168

@@ -687,7 +687,7 @@ class __SYCL_EXPORT handler {
687687
void
688688
setArgHelper(int ArgIndex,
689689
ext::oneapi::experimental::detail::dynamic_work_group_memory_base
690-
&DynWorkGroupMemParam);
690+
&DynWorkGroupBase);
691691

692692
// setArgHelper for the raw_kernel_arg extension type.
693693
void setArgHelper(int ArgIndex,
@@ -1889,13 +1889,15 @@ class __SYCL_EXPORT handler {
18891889
}
18901890

18911891
// set_arg for graph dynamic_work_group_memory
1892-
template <typename DataT>
1893-
void set_arg(int argIndex,
1894-
ext::oneapi::experimental::dynamic_work_group_memory<DataT>
1895-
&dynWorkGroupMem) {
1892+
template <typename DataT, typename PropertyListT =
1893+
ext::oneapi::experimental::empty_properties_t>
1894+
void set_arg(
1895+
int argIndex,
1896+
ext::oneapi::experimental::dynamic_work_group_memory<DataT, PropertyListT>
1897+
&dynWorkGroupMem) {
18961898
ext::oneapi::experimental::detail::dynamic_work_group_memory_base
1897-
&dynWorkGroupMemImpl = dynWorkGroupMem;
1898-
setArgHelper(argIndex, dynWorkGroupMemImpl);
1899+
&dynWorkGroupBase = dynWorkGroupMem;
1900+
setArgHelper(argIndex, dynWorkGroupBase);
18991901
}
19001902

19011903
// set_arg for the raw_kernel_arg extension type.

sycl/source/detail/graph_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1588,7 +1588,7 @@ class dynamic_parameter_impl {
15881588
// Dynamic command-groups which will be updated
15891589
std::vector<DynamicCGInfo> MDynCGs;
15901590

1591-
std::shared_ptr<graph_impl> MGraph;
1591+
std::weak_ptr<graph_impl> MGraph;
15921592
std::vector<std::byte> MValueStorage;
15931593

15941594
private:

sycl/source/detail/handler_impl.hpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,11 +199,6 @@ class handler_impl {
199199
std::vector<std::shared_ptr<detail::work_group_memory_impl>>
200200
MWorkGroupMemoryObjects;
201201

202-
/// List of dynamic work group memory objects associated with this handler
203-
std::vector<std::shared_ptr<
204-
ext::oneapi::experimental::detail::dynamic_work_group_memory_base>>
205-
MDynWorkGroupMemoryParams;
206-
207202
/// Potential event mode for the result event of the command.
208203
ext::oneapi::experimental::event_mode_enum MEventMode =
209204
ext::oneapi::experimental::event_mode_enum::none;

sycl/source/handler.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "sycl/detail/helpers.hpp"
10-
#include "sycl/ext/oneapi/experimental/graph.hpp"
1110
#include "ur_api.h"
1211
#include <algorithm>
1312

@@ -34,6 +33,7 @@
3433
#include <sycl/info/info_desc.hpp>
3534
#include <sycl/stream.hpp>
3635

36+
#include "sycl/ext/oneapi/experimental/graph.hpp"
3737
#include <sycl/ext/oneapi/bindless_images_memory.hpp>
3838
#include <sycl/ext/oneapi/experimental/work_group_memory.hpp>
3939
#include <sycl/ext/oneapi/memcpy2d.hpp>
@@ -1058,18 +1058,14 @@ void handler::setArgHelper(int ArgIndex, detail::work_group_memory_impl &Arg) {
10581058
void handler::setArgHelper(
10591059
int ArgIndex,
10601060
ext::oneapi::experimental::detail::dynamic_work_group_memory_base
1061-
&DynWorkGroupMemParam) {
1061+
&DynWorkGroupBase) {
10621062

1063-
impl->MDynWorkGroupMemoryParams.push_back(
1064-
std::make_shared<
1065-
ext::oneapi::experimental::detail::dynamic_work_group_memory_base>(
1066-
DynWorkGroupMemParam));
10671063
addArg(detail::kernel_param_kind_t::kind_dynamic_work_group_memory,
1068-
impl->MDynWorkGroupMemoryParams.back().get(), 0, ArgIndex);
1064+
&DynWorkGroupBase, 0, ArgIndex);
10691065

10701066
// Register the dynamic parameter with the handler for later association
10711067
// with the node being added
1072-
registerDynamicParameter(DynWorkGroupMemParam, ArgIndex);
1068+
registerDynamicParameter(DynWorkGroupBase, ArgIndex);
10731069
}
10741070

10751071
// The argument can take up more space to store additional information about
@@ -2102,7 +2098,7 @@ void handler::registerDynamicParameter(
21022098
}
21032099

21042100
auto Paraimpl = detail::getSyclObjImpl(DynamicParamBase);
2105-
if (Paraimpl->MGraph != this->impl->MGraph) {
2101+
if (Paraimpl->MGraph.lock() != this->impl->MGraph) {
21062102
throw sycl::exception(
21072103
make_error_code(errc::invalid),
21082104
"Cannot use a Dynamic Parameter with a node associated with a graph "

0 commit comments

Comments
 (0)