Skip to content

[SYCL] Enable host optimization of work-item free functions #2967

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,12 @@ class SYCLIntegrationHeader {
/// Registers a specialization constant to emit info for it into the header.
void addSpecConstant(StringRef IDName, QualType IDType);

/// Notes that this_item is called within the kernel.
/// Note which free functions (this_id, this_item, etc) are called within the
/// kernel
void setCallsThisId(bool B);
void setCallsThisItem(bool B);
void setCallsThisNDItem(bool B);
void setCallsThisGroup(bool B);

private:
// Kernel actual parameter descriptor.
Expand All @@ -366,6 +370,15 @@ class SYCLIntegrationHeader {
KernelParamDesc() = default;
};

// there are four free functions the kernel may call (this_id, this_item,
// this_nd_item, this_group)
struct KernelCallsSYCLFreeFunction {
bool CallsThisId;
bool CallsThisItem;
bool CallsThisNDItem;
bool CallsThisGroup;
};

// Kernel invocation descriptor
struct KernelDesc {
/// Kernel name.
Expand All @@ -385,8 +398,9 @@ class SYCLIntegrationHeader {
/// Descriptor of kernel actual parameters.
SmallVector<KernelParamDesc, 8> Params;

// Whether kernel calls this_item()
bool CallsThisItem;
// Whether kernel calls any of the SYCL free functions (this_item(),
// this_id(), etc)
KernelCallsSYCLFreeFunction FreeFunctionCalls;

KernelDesc() = default;
};
Expand Down
46 changes: 42 additions & 4 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2721,11 +2721,24 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
if (!Visited.insert(FD).second)
continue; // We've already seen this Decl

// Check whether this call is to sycl::this_item().
// Check whether this call is to free functions (sycl::this_item(),
// this_id, etc.).
if (Util::isSyclFunction(FD, "this_id")) {
Header.setCallsThisId(true);
return;
}
if (Util::isSyclFunction(FD, "this_item")) {
Header.setCallsThisItem(true);
return;
}
if (Util::isSyclFunction(FD, "this_nd_item")) {
Header.setCallsThisNDItem(true);
return;
}
if (Util::isSyclFunction(FD, "this_group")) {
Header.setCallsThisGroup(true);
return;
}

CallGraphNode *N = SYCLCG.getNode(FD);
if (!N)
Expand Down Expand Up @@ -3920,7 +3933,14 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
<< "; }\n";
O << " __SYCL_DLL_LOCAL\n";
O << " static constexpr bool callsThisItem() { return ";
O << K.CallsThisItem << "; }\n";
O << K.FreeFunctionCalls.CallsThisItem << "; }\n";
O << " __SYCL_DLL_LOCAL\n";
O << " static constexpr bool callsAnyThisFreeFunction() { return ";
O << (K.FreeFunctionCalls.CallsThisId ||
K.FreeFunctionCalls.CallsThisItem ||
K.FreeFunctionCalls.CallsThisNDItem ||
K.FreeFunctionCalls.CallsThisGroup)
<< "; }\n";
Comment on lines +3939 to +3943
Copy link
Contributor

@Fznamznon Fznamznon Dec 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels that if the final information in the integration header doesn't say which exactly free function is called, we don't need four different flags and four identical handlers (I mean SYCLIntegrationHeader::setCallsThisItem and others here) in front-end either.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the other hand I guess the host runtime could be optimized if we knows exactly which one is used...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@intel/llvm-reviewers-runtime , WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pre-existing code unrelated to this PR needs to know if this_item is called and that isn't concerned about the other free functions. The optimization in this PR just needs to know if any of them are called.
My thoughts are that it's best to make a proper record of each free function now. Recording only "this_item" and "anything" seems sloppy.
On the API side, though, the functions match our current needs. The pre-existing callsThisItem() and callsAnyThisFreeFunction() are the two affordances, but can be easily expanded in the future if required.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you expect it to be required?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know of anything pending. But I'd wager at even odds on needing to know about this_nd_item or this_group usage in the future.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright thanks! I'm ok with the FE changes.

O << "};\n";
CurStart += N;
}
Expand Down Expand Up @@ -3979,10 +3999,28 @@ void SYCLIntegrationHeader::addSpecConstant(StringRef IDName, QualType IDType) {
SpecConsts.emplace_back(std::make_pair(IDType, IDName.str()));
}

void SYCLIntegrationHeader::setCallsThisId(bool B) {
KernelDesc *K = getCurKernelDesc();
assert(K && "no kernel");
K->FreeFunctionCalls.CallsThisId = B;
}

void SYCLIntegrationHeader::setCallsThisItem(bool B) {
KernelDesc *K = getCurKernelDesc();
assert(K && "no kernels");
K->CallsThisItem = B;
assert(K && "no kernel");
K->FreeFunctionCalls.CallsThisItem = B;
}

void SYCLIntegrationHeader::setCallsThisNDItem(bool B) {
KernelDesc *K = getCurKernelDesc();
assert(K && "no kernel");
K->FreeFunctionCalls.CallsThisNDItem = B;
}

void SYCLIntegrationHeader::setCallsThisGroup(bool B) {
KernelDesc *K = getCurKernelDesc();
assert(K && "no kernel");
K->FreeFunctionCalls.CallsThisGroup = B;
}

SYCLIntegrationHeader::SYCLIntegrationHeader(DiagnosticsEngine &_Diag,
Expand Down
3 changes: 3 additions & 0 deletions clang/test/CodeGenSYCL/Inputs/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ template <int dim> struct item {
template <int Dims> item<Dims>
this_item() { return item<Dims>{}; }

template <int Dims> id<Dims>
this_id() { return id<Dims>{}; }

template <int dim>
struct range {
template <typename... T>
Expand Down
4 changes: 2 additions & 2 deletions clang/test/CodeGenSYCL/kernel-by-reference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ int simple_add(int i) {
int main() {
queue q;
#if defined(SYCL2020)
// expected-warning@Inputs/sycl.hpp:298 {{Passing kernel functions by value is deprecated in SYCL 2020}}
// expected-warning@Inputs/sycl.hpp:301 {{Passing kernel functions by value is deprecated in SYCL 2020}}
// expected-note@+3 {{in instantiation of function template specialization}}
#endif
q.submit([&](handler &h) {
h.single_task_2017<class sycl2017>([]() { simple_add(10); });
});

#if defined(SYCL2017)
// expected-warning@Inputs/sycl.hpp:293 {{Passing of kernel functions by reference is a SYCL 2020 extension}}
// expected-warning@Inputs/sycl.hpp:296 {{Passing of kernel functions by reference is a SYCL 2020 extension}}
// expected-note@+3 {{in instantiation of function template specialization}}
#endif
q.submit([&](handler &h) {
Expand Down
31 changes: 30 additions & 1 deletion clang/test/CodeGenSYCL/parallel_for_this_item.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3GNU",
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3EMU",
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3OWL",
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3RAT"
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3RAT",
// CHECK-NEXT: "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3FOX"
// CHECK-NEXT: };

// CHECK:template <> struct KernelInfo<class GNU> {
Expand All @@ -29,6 +30,8 @@
// CHECK-NEXT: static constexpr bool isESIMD() { return 0; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool callsThisItem() { return 0; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool callsAnyThisFreeFunction() { return 0; }
// CHECK-NEXT:};
// CHECK-NEXT:template <> struct KernelInfo<class EMU> {
// CHECK-NEXT: __SYCL_DLL_LOCAL
Expand All @@ -43,6 +46,8 @@
// CHECK-NEXT: static constexpr bool isESIMD() { return 0; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool callsThisItem() { return 1; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool callsAnyThisFreeFunction() { return 1; }
// CHECK-NEXT:};
// CHECK-NEXT:template <> struct KernelInfo<class OWL> {
// CHECK-NEXT: __SYCL_DLL_LOCAL
Expand All @@ -57,6 +62,8 @@
// CHECK-NEXT: static constexpr bool isESIMD() { return 0; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool callsThisItem() { return 0; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool callsAnyThisFreeFunction() { return 0; }
// CHECK-NEXT:};
// CHECK-NEXT:template <> struct KernelInfo<class RAT> {
// CHECK-NEXT: __SYCL_DLL_LOCAL
Expand All @@ -71,6 +78,24 @@
// CHECK-NEXT: static constexpr bool isESIMD() { return 0; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool callsThisItem() { return 1; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool callsAnyThisFreeFunction() { return 1; }
// CHECK-NEXT:};
// CHECK-NEXT:template <> struct KernelInfo<class FOX> {
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr const char* getName() { return "_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE3FOX"; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr unsigned getNumParams() { return 0; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr const kernel_param_desc_t& getParamDesc(unsigned i) {
// CHECK-NEXT: return kernel_signatures[i+0];
// CHECK-NEXT: }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool isESIMD() { return 0; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool callsThisItem() { return 0; }
// CHECK-NEXT: __SYCL_DLL_LOCAL
// CHECK-NEXT: static constexpr bool callsAnyThisFreeFunction() { return 1; }
// CHECK-NEXT:};

#include "sycl.hpp"
Expand Down Expand Up @@ -108,6 +133,10 @@ int main() {

// This kernel calls sycl::this_item
cgh.parallel_for<class RAT>(range<1>(1), [=](id<1> I) { f(); });

// This kernel does not call sycl::this_item, but does call this_id
cgh.parallel_for<class FOX>(range<1>(1),
[=](id<1> I) { this_id<1>(); });
});

return 0;
Expand Down
1 change: 1 addition & 0 deletions sycl/include/CL/sycl/ONEAPI/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <CL/sycl/ONEAPI/group_algorithm.hpp>
#include <CL/sycl/accessor.hpp>
#include <CL/sycl/handler.hpp>
#include <CL/sycl/kernel.hpp>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {
Expand Down
48 changes: 36 additions & 12 deletions sycl/include/CL/sycl/detail/cg_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class HostTask {
};

// Class which stores specific lambda object.
template <class KernelType, class KernelArgType, int Dims>
template <class KernelType, class KernelArgType, int Dims, typename KernelName>
class HostKernel : public HostKernelBase {
using IDBuilder = sycl::detail::Builder;
KernelType MKernel;
Expand Down Expand Up @@ -203,6 +203,9 @@ class HostKernel : public HostKernelBase {
template <class ArgT = KernelArgType>
typename detail::enable_if_t<std::is_same<ArgT, sycl::id<Dims>>::value>
runOnHost(const NDRDescT &NDRDesc) {
using KI = detail::KernelInfo<KernelName>;
constexpr bool StoreLocation = KI::callsAnyThisFreeFunction();

sycl::range<Dims> Range(InitializedVal<Dims, range>::template get<0>());
sycl::id<Dims> Offset;
for (int I = 0; I < Dims; ++I) {
Expand All @@ -213,8 +216,11 @@ class HostKernel : public HostKernelBase {
detail::NDLoop<Dims>::iterate(Range, [&](const sycl::id<Dims> &ID) {
sycl::item<Dims, /*Offset=*/true> Item =
IDBuilder::createItem<Dims, true>(Range, ID, Offset);
store_id(&ID);
store_item(&Item);

if (StoreLocation) {
store_id(&ID);
store_item(&Item);
}
MKernel(ID);
});
}
Expand All @@ -223,6 +229,9 @@ class HostKernel : public HostKernelBase {
typename detail::enable_if_t<
std::is_same<ArgT, item<Dims, /*Offset=*/false>>::value>
runOnHost(const NDRDescT &NDRDesc) {
using KI = detail::KernelInfo<KernelName>;
constexpr bool StoreLocation = KI::callsAnyThisFreeFunction();

sycl::id<Dims> ID;
sycl::range<Dims> Range(InitializedVal<Dims, range>::template get<0>());
for (int I = 0; I < Dims; ++I)
Expand All @@ -232,8 +241,11 @@ class HostKernel : public HostKernelBase {
sycl::item<Dims, /*Offset=*/false> Item =
IDBuilder::createItem<Dims, false>(Range, ID);
sycl::item<Dims, /*Offset=*/true> ItemWithOffset = Item;
store_id(&ID);
store_item(&ItemWithOffset);

if (StoreLocation) {
store_id(&ID);
store_item(&ItemWithOffset);
}
MKernel(Item);
});
}
Expand All @@ -242,6 +254,9 @@ class HostKernel : public HostKernelBase {
typename detail::enable_if_t<
std::is_same<ArgT, item<Dims, /*Offset=*/true>>::value>
runOnHost(const NDRDescT &NDRDesc) {
using KI = detail::KernelInfo<KernelName>;
constexpr bool StoreLocation = KI::callsAnyThisFreeFunction();

sycl::range<Dims> Range(InitializedVal<Dims, range>::template get<0>());
sycl::id<Dims> Offset;
for (int I = 0; I < Dims; ++I) {
Expand All @@ -253,15 +268,21 @@ class HostKernel : public HostKernelBase {
sycl::id<Dims> OffsetID = ID + Offset;
sycl::item<Dims, /*Offset=*/true> Item =
IDBuilder::createItem<Dims, true>(Range, OffsetID, Offset);
store_id(&OffsetID);
store_item(&Item);

if (StoreLocation) {
store_id(&OffsetID);
store_item(&Item);
}
MKernel(Item);
});
}

template <class ArgT = KernelArgType>
typename detail::enable_if_t<std::is_same<ArgT, nd_item<Dims>>::value>
runOnHost(const NDRDescT &NDRDesc) {
using KI = detail::KernelInfo<KernelName>;
constexpr bool StoreLocation = KI::callsAnyThisFreeFunction();

sycl::range<Dims> GroupSize(InitializedVal<Dims, range>::template get<0>());
for (int I = 0; I < Dims; ++I) {
if (NDRDesc.LocalSize[I] == 0 ||
Expand Down Expand Up @@ -294,11 +315,14 @@ class HostKernel : public HostKernelBase {
IDBuilder::createItem<Dims, false>(LocalSize, LocalID);
const sycl::nd_item<Dims> NDItem =
IDBuilder::createNDItem<Dims>(GlobalItem, LocalItem, Group);
store_id(&GlobalID);
store_item(&GlobalItem);
store_nd_item(&NDItem);
auto g = NDItem.get_group();
store_group(&g);

if (StoreLocation) {
store_id(&GlobalID);
store_item(&GlobalItem);
store_nd_item(&NDItem);
auto g = NDItem.get_group();
store_group(&g);
}
MKernel(NDItem);
});
});
Expand Down
2 changes: 2 additions & 0 deletions sycl/include/CL/sycl/detail/kernel_desc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ template <class KernelNameType> struct KernelInfo {
static constexpr const char *getName() { return ""; }
static constexpr bool isESIMD() { return 0; }
static constexpr bool callsThisItem() { return false; }
static constexpr bool callsAnyThisFreeFunction() { return false; }
};
#else
template <char...> struct KernelInfoData {
Expand All @@ -69,6 +70,7 @@ template <char...> struct KernelInfoData {
static constexpr const char *getName() { return ""; }
static constexpr bool isESIMD() { return 0; }
static constexpr bool callsThisItem() { return false; }
static constexpr bool callsAnyThisFreeFunction() { return false; }
};

// C++14 like index_sequence and make_index_sequence
Expand Down
Loading