Skip to content

[SYCL] Support multiple call operators in kernel functor #8525

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d172f64
[SYCL] Support multiple call operators in kernel
srividya-sundaram Mar 2, 2023
e17002d
Add code to handle multiple call ops in kernel functor.
srividya-sundaram Mar 7, 2023
4f2dc5b
Modify test file
srividya-sundaram Mar 7, 2023
75f1919
Fix failing test.
srividya-sundaram Mar 8, 2023
a8bd38c
Address review comments.
srividya-sundaram Mar 10, 2023
e73efaa
Re-write code using RecursiveASTVisitor.
srividya-sundaram Mar 15, 2023
caddf21
Update test.
srividya-sundaram Mar 16, 2023
53d8de0
Fix review comments.
srividya-sundaram Mar 17, 2023
ef3ccd3
Address review comments.
srividya-sundaram Mar 18, 2023
8cc706f
Address code review comments.
srividya-sundaram Mar 23, 2023
78c5bdc
Fix ESIMD code.
srividya-sundaram Mar 24, 2023
8ef7f47
Move ESIMD check to SyclKernelDeclCreator.
srividya-sundaram Mar 28, 2023
5548aae
Remove ESIMD init method generation from KernelArgsSizeChecker.
srividya-sundaram Mar 28, 2023
ab95f49
Fix review comments.
srividya-sundaram Mar 28, 2023
0680ce8
Add tests.
srividya-sundaram Mar 29, 2023
c971930
Remove extra line.
srividya-sundaram Mar 29, 2023
75a5c3f
Fix typos.
srividya-sundaram Mar 29, 2023
42bd239
Fix failing tests.
srividya-sundaram Mar 29, 2023
95c181c
Fix test failure.
srividya-sundaram Mar 30, 2023
e18a1ce
Add IsSIMD flag to SyclKernelBodyCreator.
srividya-sundaram Apr 2, 2023
4eead59
Add comments.
srividya-sundaram Apr 2, 2023
2cf8b6d
Modify copySYCLKernelAttrs signature to take reference to KernelCallO…
srividya-sundaram Apr 3, 2023
98c9839
Remove lambda check from Visitor class.
srividya-sundaram Apr 3, 2023
5361bd8
Pass Call operator to Visitor and functions
srividya-sundaram Apr 3, 2023
26ea0bd
Remove unnecessary code.
srividya-sundaram Apr 3, 2023
ae1dc6c
Add new ESIMD Visitor.
srividya-sundaram Apr 6, 2023
df14a30
Move ArgsSizeChecker to ConstructOpenCLKernel.
srividya-sundaram Apr 6, 2023
9f99f92
Fix test.
srividya-sundaram Apr 6, 2023
c41945f
Move ESIMD check after SYCL kernel check.
srividya-sundaram Apr 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -14274,7 +14274,8 @@ class Sema final {

bool isDeclAllowedInSYCLDeviceCode(const Decl *D);
void checkSYCLDeviceVarDecl(VarDecl *Var);
void copySYCLKernelAttrs(const CXXRecordDecl *KernelObj);
void copySYCLKernelAttrs(const CXXRecordDecl *KernelObj,
FunctionDecl *KernelCallerFunc);
void ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc, MangleContext &MC);
void SetSYCLKernelNames();
void MarkDevices();
Expand Down
123 changes: 104 additions & 19 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@

#include "TreeTransform.h"
#include "clang/AST/AST.h"
#include "clang/AST/AttrVisitor.h"
#include "clang/AST/DeclVisitor.h"
#include "clang/AST/Mangle.h"
#include "clang/AST/QualTypeNames.h"
#include "clang/AST/RecordLayout.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
#include "clang/AST/TemplateArgumentVisitor.h"
#include "clang/AST/TypeVisitor.h"
#include "clang/Analysis/CallGraph.h"
Expand Down Expand Up @@ -2784,16 +2787,77 @@ class SyclOptReportCreator : public SyclKernelFieldHandler {
}
};

static CXXMethodDecl *getOperatorParens(const CXXRecordDecl *Rec) {
for (auto *MD : Rec->methods()) {
if (MD->getOverloadedOperator() == OO_Call)
return MD;
// This visitor traverses the AST of the function with the `sycl_kernel`
// attribute and returns the version of 'operator()()' that is called by
// kernelFunc(). There will only be one call to kernelFunc() in that AST
// because the DPC++ headers are structured such that the user's kernel
// function is only called once. This ensures that the correct 'operator()()'
// is returned when a named function object used to define a kernel has more
// than one 'operator()()' defined in it. For example, in the code below,
// 'operator()(sycl::id<1> id)' is returned based on the 'parallel_for'
// invocation which takes a 'sycl::range<1>(16)' argument.
// class MyKernel {
// public:
// void operator()() const {
// // code
// }
//
// [[intel::reqd_sub_group_size(4)]] void operator()(sycl::id<1> id) const
// {
// // code
// }
// };
//
// int main() {
//
// Q.submit([&](sycl::handler& cgh) {
// MyKernel kernelFunctorObject;
// cgh.parallel_for(sycl::range<1>(16), kernelFunctorObject);
// });
// return 0;
// }

/// AST visitor that looks for a call to an overloaded 'operator()' declared
/// directly in the kernel object's record type and remembers the method that
/// is being called.
///
/// Typical use: construct the visitor, then either call TraverseDecl() on the
/// kernel caller function and read CallOperator, or simply call
/// getCallOperator(), which performs the traversal on demand.
class KernelCallOperatorVisitor
    : public RecursiveASTVisitor<KernelCallOperatorVisitor> {

  /// Function whose body is searched by getCallOperator().
  FunctionDecl *KernelCallerFunc;

public:
  /// The matching call operator seen most recently during traversal, or
  /// nullptr when no matching call has been visited (yet).
  CXXMethodDecl *CallOperator = nullptr;

  /// Record type of the kernel object. Only call operators whose parent is
  /// exactly this record are considered a match.
  const CXXRecordDecl *KernelObj;

  KernelCallOperatorVisitor(FunctionDecl *KernelCallerFunc,
                            const CXXRecordDecl *KernelObj)
      : KernelCallerFunc(KernelCallerFunc), KernelObj(KernelObj) {}

  /// Records the callee of \p CE in CallOperator when it is an 'operator()'
  /// that belongs to KernelObj.
  ///
  /// \returns true unconditionally, so that the traversal keeps going.
  bool VisitCallExpr(CallExpr *CE) {
    // getCalleeDecl() may be null (e.g. no resolvable callee declaration), so
    // use the null-tolerant cast instead of a separate isa<> + cast<> pair.
    auto *MD = dyn_cast_or_null<CXXMethodDecl>(CE->getCalleeDecl());
    if (MD && MD->getOverloadedOperator() == OO_Call &&
        MD->getParent() == KernelObj)
      CallOperator = MD;
    return true;
  }

  /// Returns the call operator invoked from the kernel caller function,
  /// traversing that function first if no operator has been recorded so far.
  ///
  /// \returns the recorded method, or nullptr if the traversal found none.
  CXXMethodDecl *getCallOperator() {
    if (!CallOperator)
      TraverseDecl(KernelCallerFunc);
    return CallOperator;
  }
};

static bool isESIMDKernelType(KernelCallOperatorVisitor KernelCallOperator,
const CXXRecordDecl *KernelObjType,
FunctionDecl *KernelCallerFunc, Sema &SemaRef) {
const CXXMethodDecl *OpParens = nullptr;

if (KernelObjType->isLambda()) {
for (auto *MD : KernelObjType->methods()) {
if (MD->getOverloadedOperator() == OO_Call)
OpParens = MD;
}
} else {
KernelCallOperator.TraverseDecl(KernelCallerFunc);
OpParens = KernelCallOperator.CallOperator;
}
return nullptr;
}

static bool isESIMDKernelType(const CXXRecordDecl *KernelObjType) {
const CXXMethodDecl *OpParens = getOperatorParens(KernelObjType);
return (OpParens != nullptr) && OpParens->hasAttr<SYCLSimdAttr>();
}

Expand Down Expand Up @@ -2869,6 +2933,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
}

void annotateHierarchicalParallelismAPICalls() {

// Is this a hierarchical parallelism kernel invocation?
if (getKernelInvocationKind(KernelCallerFunc) != InvokeParallelForWorkGroup)
return;
Expand All @@ -2882,11 +2947,14 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
// (of either the lambda or the function object).
CXXRecordDecl *KernelObj =
GetSYCLKernelObjectType(KernelCallerFunc)->getAsCXXRecordDecl();

KernelCallOperatorVisitor KernelCallOperator(KernelCallerFunc, KernelObj);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically the isLambda check can be done here as well. You don't really need to construct this object for lambdas.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand how this suggestion was applied, we still construct the object, but there is also an additional check on Lambda...

KernelCallOperator.TraverseDecl(KernelCallerFunc);
CXXMethodDecl *WGLambdaFn = nullptr;
if (KernelObj->isLambda())
WGLambdaFn = KernelObj->getLambdaCallOperator();
else
WGLambdaFn = getOperatorParens(KernelObj);
WGLambdaFn = KernelCallOperator.CallOperator;
assert(WGLambdaFn && "non callable object is passed as kernel obj");
// Mark the function that it "works" in a work group scope:
// NOTE: In case of parallel_for_work_item the marker call itself is
Expand Down Expand Up @@ -3199,7 +3267,10 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
}

const llvm::StringLiteral getInitMethodName() const {
bool IsSIMDKernel = isESIMDKernelType(KernelObj);
KernelCallOperatorVisitor KernelCallOperator(KernelCallerFunc, KernelObj);

bool IsSIMDKernel = isESIMDKernelType(KernelCallOperator, KernelObj,
KernelCallerFunc, SemaRef);
return IsSIMDKernel ? InitESIMDMethodName : InitMethodName;
}

Expand Down Expand Up @@ -3550,6 +3621,7 @@ static bool IsSYCLUnnamedKernel(Sema &SemaRef, const FunctionDecl *FD) {

class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
SYCLIntegrationHeader &Header;
KernelCallOperatorVisitor KernelCallOperator;
int64_t CurOffset = 0;
llvm::SmallVector<size_t, 16> ArrayBaseOffsets;
int StructDepth = 0;
Expand Down Expand Up @@ -3581,11 +3653,14 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {

public:
static constexpr const bool VisitInsideSimpleContainers = false;
SyclKernelIntHeaderCreator(Sema &S, SYCLIntegrationHeader &H,
SyclKernelIntHeaderCreator(KernelCallOperatorVisitor KernelCallOperator,
Sema &S, SYCLIntegrationHeader &H,
const CXXRecordDecl *KernelObj, QualType NameType,
FunctionDecl *KernelFunc)
: SyclKernelFieldHandler(S), Header(H) {
bool IsSIMDKernel = isESIMDKernelType(KernelObj);
: SyclKernelFieldHandler(S), Header(H),
KernelCallOperator(KernelCallOperator) {
bool IsSIMDKernel =
isESIMDKernelType(KernelCallOperator, KernelObj, KernelFunc, S);
// The header needs to access the kernel object size.
int64_t ObjSize = SemaRef.getASTContext()
.getTypeSizeInChars(KernelObj->getTypeForDecl())
Expand Down Expand Up @@ -3969,6 +4044,9 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc,
const CXXRecordDecl *KernelObj =
GetSYCLKernelObjectType(KernelFunc)->getAsCXXRecordDecl();

KernelCallOperatorVisitor KernelCallOperator(KernelFunc, KernelObj);
KernelCallOperator.TraverseDecl(KernelFunc);

if (!KernelObj) {
Diag(Args[0]->getExprLoc(), diag::err_sycl_kernel_not_function_object);
KernelFunc->setInvalidDecl();
Expand Down Expand Up @@ -3999,7 +4077,8 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc,
if (KernelObj->isInvalidDecl())
return;

bool IsSIMDKernel = isESIMDKernelType(KernelObj);
bool IsSIMDKernel =
isESIMDKernelType(KernelCallOperator, KernelObj, KernelFunc, *this);

SyclKernelDecompMarker DecompMarker(*this);
SyclKernelFieldChecker FieldChecker(*this, IsSIMDKernel);
Expand Down Expand Up @@ -4033,9 +4112,12 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc,

// For a wrapped parallel_for, copy attributes from original
// kernel to wrapped kernel.
void Sema::copySYCLKernelAttrs(const CXXRecordDecl *KernelObj) {
void Sema::copySYCLKernelAttrs(const CXXRecordDecl *KernelObj,
FunctionDecl *KernelCallerFunc) {
// Get the operator() function of the wrapper.
CXXMethodDecl *OpParens = getOperatorParens(KernelObj);
KernelCallOperatorVisitor KernelCallOperator(KernelCallerFunc, KernelObj);
KernelCallOperator.TraverseDecl(KernelCallerFunc);
CXXMethodDecl *OpParens = KernelCallOperator.CallOperator;
assert(OpParens && "invalid kernel object");

typedef std::pair<FunctionDecl *, FunctionDecl *> ChildParentPair;
Expand Down Expand Up @@ -4148,18 +4230,21 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
// Attributes of a user-written SYCL kernel must be copied to the internally
// generated alternative kernel, identified by a known string in its name.
if (StableName.find("__pf_kernel_wrapper") != std::string::npos)
copySYCLKernelAttrs(KernelObj);
copySYCLKernelAttrs(KernelObj, KernelCallerFunc);
}

bool IsSIMDKernel = isESIMDKernelType(KernelObj);
KernelCallOperatorVisitor KernelCallOperator(KernelCallerFunc, KernelObj);

bool IsSIMDKernel =
isESIMDKernelType(KernelCallOperator, KernelObj, KernelCallerFunc, *this);

SyclKernelDeclCreator kernel_decl(*this, KernelObj->getLocation(),
KernelCallerFunc->isInlined(), IsSIMDKernel,
KernelCallerFunc);
SyclKernelBodyCreator kernel_body(*this, kernel_decl, KernelObj,
KernelCallerFunc);
SyclKernelIntHeaderCreator int_header(
*this, getSyclIntegrationHeader(), KernelObj,
KernelCallOperator, *this, getSyclIntegrationHeader(), KernelObj,
calculateKernelNameType(Context, KernelCallerFunc), KernelCallerFunc);

SyclKernelIntFooterCreator int_footer(*this, getSyclIntegrationFooter());
Expand Down
7 changes: 7 additions & 0 deletions clang/test/CodeGenSYCL/Inputs/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,13 @@ kernel_parallel_for_work_group(const KernelType &KernelFunc) {

class handler {
public:

template <typename KernelName = auto_name, typename KernelType>
void parallel_for(const KernelType &kernelObj) {
using NameT = typename get_kernel_name_t<KernelName, KernelType>::name;
kernel_parallel_for<NameT>(kernelObj);
}

template <typename KernelName = auto_name, typename KernelType, int Dims>
void parallel_for(range<Dims> numWorkItems, const KernelType &kernelFunc) {
using NameT = typename get_kernel_name_t<KernelName, KernelType>::name;
Expand Down
46 changes: 46 additions & 0 deletions clang/test/CodeGenSYCL/kernel-op-calls.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s
// This test checks that the correct call operator is used for the kernel when
// the kernel functor defines more than one 'operator()'.

#include "sycl.hpp"

sycl::queue Q;

// Check that a functor with multiple call operators works.
// Functor1 declares two 'operator()' overloads that differ in their parameter
// type and in the attribute attached to them, so the attribute metadata on the
// generated kernel indicates which overload was used.
class Functor1 {
public:
Functor1(){}

// Overload taking a sycl::id<1>; carries the sub-group size attribute.
[[intel::reqd_sub_group_size(4)]] void operator()(sycl::id<1> id) const {}

// Overload taking a sycl::id<2>; carries the work-group size hint attribute.
[[sycl::work_group_size_hint(1, 2, 3)]] void operator()(sycl::id<2> id) const {}

};

// Check that a templated 'operator()()' works.
// The only call operator of this functor is a member template, parameterized
// on the dimensionality of the sycl::id it receives.
class kernels {
public:
kernels(){}

// Dimensions defaults to 1; the attribute is attached to the template itself.
template<int Dimensions = 1>
[[sycl::work_group_size_hint(1, 2, 3)]] void operator()(sycl::id<Dimensions> item) const {}

};

// Submits one kernel per functor. Each CHECK line pins the attribute metadata
// kinds attached to the generated kernel, which identifies the call operator
// that was used. Metadata node numbers and the attribute group number are
// matched with regexes because they change whenever unrelated metadata is
// added to or removed from the module.
int main() {

  Q.submit([&](sycl::handler& cgh) {
    Functor1 F;
    // CHECK: define dso_local spir_kernel void @_ZTS8Functor1() #{{[0-9]+}} !srcloc !{{[0-9]+}} !kernel_arg_buffer_location !{{[0-9]+}} !intel_reqd_sub_group_size !{{[0-9]+}} !sycl_fixed_targets !{{[0-9]+}} {
    cgh.parallel_for(sycl::range<1>(10), F);
  });

  Q.submit([&](sycl::handler& cgh) {
    kernels K;
    // CHECK: define dso_local spir_kernel void @_ZTS7kernels() #{{[0-9]+}} !srcloc !{{[0-9]+}} !kernel_arg_buffer_location !{{[0-9]+}} !work_group_size_hint !{{[0-9]+}} !sycl_fixed_targets !{{[0-9]+}} {
    cgh.parallel_for(sycl::range<1>(10), K);
  });

  return 0;
}