-
Notifications
You must be signed in to change notification settings - Fork 787
[SYCL] Support multiple call operators in kernel functor #8525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
d172f64
e17002d
4f2dc5b
75f1919
a8bd38c
e73efaa
caddf21
53d8de0
ef3ccd3
8cc706f
78c5bdc
8ef7f47
5548aae
ab95f49
0680ce8
c971930
75a5c3f
42bd239
95c181c
e18a1ce
4eead59
2cf8b6d
98c9839
5361bd8
26ea0bd
ae1dc6c
df14a30
9f99f92
c41945f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,10 +10,13 @@ | |
|
||
#include "TreeTransform.h" | ||
#include "clang/AST/AST.h" | ||
#include "clang/AST/AttrVisitor.h" | ||
#include "clang/AST/DeclVisitor.h" | ||
#include "clang/AST/Mangle.h" | ||
#include "clang/AST/QualTypeNames.h" | ||
#include "clang/AST/RecordLayout.h" | ||
#include "clang/AST/RecursiveASTVisitor.h" | ||
#include "clang/AST/StmtVisitor.h" | ||
srividya-sundaram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
#include "clang/AST/TemplateArgumentVisitor.h" | ||
#include "clang/AST/TypeVisitor.h" | ||
#include "clang/Analysis/CallGraph.h" | ||
|
@@ -2784,16 +2787,77 @@ class SyclOptReportCreator : public SyclKernelFieldHandler { | |
} | ||
}; | ||
|
||
static CXXMethodDecl *getOperatorParens(const CXXRecordDecl *Rec) { | ||
for (auto *MD : Rec->methods()) { | ||
if (MD->getOverloadedOperator() == OO_Call) | ||
return MD; | ||
// This Visitor traverses the AST of the function with | ||
// `sycl_kernel` attribute and returns the version of “operator()()” that is | ||
// called by kernelFunc(). There will only be one call to kernelFunc() in that | ||
srividya-sundaram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// AST because the DPC++ headers are structured such that the user’s | ||
// kernel function is only called once. This ensures that the correct | ||
// “operator()()” function call is returned, when a named function object used | ||
// to define a kernel has more than one “operator()()” calls defined in it. For | ||
// example, in the code below, 'operator()(sycl::id<1> id)' is returned based on | ||
// the 'parallel_for' invocation which takes a 'sycl::range<1>(16)' argument. | ||
// class MyKernel { | ||
// public: | ||
// void operator()() const { | ||
// // code | ||
// } | ||
// | ||
// [[intel::reqd_sub_group_size(4)]] void operator()(sycl::id<1> id) const | ||
// { | ||
// // code | ||
// } | ||
// }; | ||
// | ||
// int main() { | ||
// | ||
// Q.submit([&](sycl::handler& cgh) { | ||
// MyKernel kernelFunctorObject; | ||
// cgh.parallel_for(sycl::range<1>(16), kernelFunctorObject); | ||
// }); | ||
// return 0; | ||
// } | ||
|
||
class KernelCallOperatorVisitor | ||
: public RecursiveASTVisitor<KernelCallOperatorVisitor> { | ||
|
||
FunctionDecl *KernelCallerFunc; | ||
|
||
public: | ||
CXXMethodDecl *CallOperator = nullptr; | ||
const CXXRecordDecl *KernelObj; | ||
|
||
KernelCallOperatorVisitor(FunctionDecl *KernelCallerFunc, | ||
const CXXRecordDecl *KernelObj) | ||
: KernelCallerFunc(KernelCallerFunc), KernelObj(KernelObj) {} | ||
|
||
bool VisitCallExpr(CallExpr *CE) { | ||
Decl *CalleeDecl = CE->getCalleeDecl(); | ||
if (isa_and_nonnull<CXXMethodDecl>(CalleeDecl)) { | ||
CXXMethodDecl *MD = cast<CXXMethodDecl>(CalleeDecl); | ||
if (MD->getOverloadedOperator() == OO_Call && | ||
MD->getParent() == KernelObj) { | ||
CallOperator = MD; | ||
} | ||
} | ||
return true; | ||
} | ||
srividya-sundaram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
}; | ||
|
||
static bool isESIMDKernelType(KernelCallOperatorVisitor KernelCallOperator, | ||
const CXXRecordDecl *KernelObjType, | ||
FunctionDecl *KernelCallerFunc, Sema &SemaRef) { | ||
srividya-sundaram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
const CXXMethodDecl *OpParens = nullptr; | ||
|
||
if (KernelObjType->isLambda()) { | ||
for (auto *MD : KernelObjType->methods()) { | ||
if (MD->getOverloadedOperator() == OO_Call) | ||
OpParens = MD; | ||
} | ||
srividya-sundaram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} else { | ||
KernelCallOperator.TraverseDecl(KernelCallerFunc); | ||
OpParens = KernelCallOperator.CallOperator; | ||
} | ||
return nullptr; | ||
} | ||
|
||
static bool isESIMDKernelType(const CXXRecordDecl *KernelObjType) { | ||
const CXXMethodDecl *OpParens = getOperatorParens(KernelObjType); | ||
return (OpParens != nullptr) && OpParens->hasAttr<SYCLSimdAttr>(); | ||
} | ||
|
||
|
@@ -2869,6 +2933,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { | |
} | ||
|
||
void annotateHierarchicalParallelismAPICalls() { | ||
|
||
srividya-sundaram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Is this a hierarchical parallelism kernel invocation? | ||
if (getKernelInvocationKind(KernelCallerFunc) != InvokeParallelForWorkGroup) | ||
return; | ||
|
@@ -2882,11 +2947,14 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { | |
// (of either the lambda or the function object). | ||
CXXRecordDecl *KernelObj = | ||
GetSYCLKernelObjectType(KernelCallerFunc)->getAsCXXRecordDecl(); | ||
|
||
KernelCallOperatorVisitor KernelCallOperator(KernelCallerFunc, KernelObj); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Technically the isLambda check can be done here as well. You don't really need to construct this object for lambdas. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure I understand how this suggestion was applied, we still construct the object, but there is also an additional check on Lambda... |
||
KernelCallOperator.TraverseDecl(KernelCallerFunc); | ||
srividya-sundaram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
CXXMethodDecl *WGLambdaFn = nullptr; | ||
if (KernelObj->isLambda()) | ||
WGLambdaFn = KernelObj->getLambdaCallOperator(); | ||
else | ||
WGLambdaFn = getOperatorParens(KernelObj); | ||
WGLambdaFn = KernelCallOperator.CallOperator; | ||
assert(WGLambdaFn && "non callable object is passed as kernel obj"); | ||
// Mark the function that it "works" in a work group scope: | ||
// NOTE: In case of parallel_for_work_item the marker call itself is | ||
|
@@ -3199,7 +3267,10 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler { | |
} | ||
|
||
const llvm::StringLiteral getInitMethodName() const { | ||
bool IsSIMDKernel = isESIMDKernelType(KernelObj); | ||
KernelCallOperatorVisitor KernelCallOperator(KernelCallerFunc, KernelObj); | ||
srividya-sundaram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
bool IsSIMDKernel = isESIMDKernelType(KernelCallOperator, KernelObj, | ||
KernelCallerFunc, SemaRef); | ||
return IsSIMDKernel ? InitESIMDMethodName : InitMethodName; | ||
} | ||
|
||
|
@@ -3550,6 +3621,7 @@ static bool IsSYCLUnnamedKernel(Sema &SemaRef, const FunctionDecl *FD) { | |
|
||
class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { | ||
SYCLIntegrationHeader &Header; | ||
KernelCallOperatorVisitor KernelCallOperator; | ||
int64_t CurOffset = 0; | ||
llvm::SmallVector<size_t, 16> ArrayBaseOffsets; | ||
int StructDepth = 0; | ||
|
@@ -3581,11 +3653,14 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler { | |
|
||
public: | ||
static constexpr const bool VisitInsideSimpleContainers = false; | ||
SyclKernelIntHeaderCreator(Sema &S, SYCLIntegrationHeader &H, | ||
SyclKernelIntHeaderCreator(KernelCallOperatorVisitor KernelCallOperator, | ||
Sema &S, SYCLIntegrationHeader &H, | ||
const CXXRecordDecl *KernelObj, QualType NameType, | ||
FunctionDecl *KernelFunc) | ||
: SyclKernelFieldHandler(S), Header(H) { | ||
bool IsSIMDKernel = isESIMDKernelType(KernelObj); | ||
: SyclKernelFieldHandler(S), Header(H), | ||
KernelCallOperator(KernelCallOperator) { | ||
bool IsSIMDKernel = | ||
isESIMDKernelType(KernelCallOperator, KernelObj, KernelFunc, S); | ||
// The header needs to access the kernel object size. | ||
int64_t ObjSize = SemaRef.getASTContext() | ||
.getTypeSizeInChars(KernelObj->getTypeForDecl()) | ||
|
@@ -3969,6 +4044,9 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, | |
const CXXRecordDecl *KernelObj = | ||
GetSYCLKernelObjectType(KernelFunc)->getAsCXXRecordDecl(); | ||
|
||
KernelCallOperatorVisitor KernelCallOperator(KernelFunc, KernelObj); | ||
KernelCallOperator.TraverseDecl(KernelFunc); | ||
srividya-sundaram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if (!KernelObj) { | ||
Diag(Args[0]->getExprLoc(), diag::err_sycl_kernel_not_function_object); | ||
KernelFunc->setInvalidDecl(); | ||
|
@@ -3999,7 +4077,8 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, | |
if (KernelObj->isInvalidDecl()) | ||
return; | ||
|
||
bool IsSIMDKernel = isESIMDKernelType(KernelObj); | ||
bool IsSIMDKernel = | ||
isESIMDKernelType(KernelCallOperator, KernelObj, KernelFunc, *this); | ||
srividya-sundaram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
SyclKernelDecompMarker DecompMarker(*this); | ||
SyclKernelFieldChecker FieldChecker(*this, IsSIMDKernel); | ||
|
@@ -4033,9 +4112,12 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, | |
|
||
// For a wrapped parallel_for, copy attributes from original | ||
// kernel to wrapped kernel. | ||
void Sema::copySYCLKernelAttrs(const CXXRecordDecl *KernelObj) { | ||
void Sema::copySYCLKernelAttrs(const CXXRecordDecl *KernelObj, | ||
FunctionDecl *KernelCallerFunc) { | ||
// Get the operator() function of the wrapper. | ||
CXXMethodDecl *OpParens = getOperatorParens(KernelObj); | ||
KernelCallOperatorVisitor KernelCallOperator(KernelCallerFunc, KernelObj); | ||
KernelCallOperator.TraverseDecl(KernelCallerFunc); | ||
srividya-sundaram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
CXXMethodDecl *OpParens = KernelCallOperator.CallOperator; | ||
assert(OpParens && "invalid kernel object"); | ||
|
||
typedef std::pair<FunctionDecl *, FunctionDecl *> ChildParentPair; | ||
|
@@ -4148,18 +4230,21 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc, | |
// Attributes of a user-written SYCL kernel must be copied to the internally | ||
// generated alternative kernel, identified by a known string in its name. | ||
if (StableName.find("__pf_kernel_wrapper") != std::string::npos) | ||
copySYCLKernelAttrs(KernelObj); | ||
copySYCLKernelAttrs(KernelObj, KernelCallerFunc); | ||
srividya-sundaram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
bool IsSIMDKernel = isESIMDKernelType(KernelObj); | ||
KernelCallOperatorVisitor KernelCallOperator(KernelCallerFunc, KernelObj); | ||
srividya-sundaram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
bool IsSIMDKernel = | ||
isESIMDKernelType(KernelCallOperator, KernelObj, KernelCallerFunc, *this); | ||
|
||
SyclKernelDeclCreator kernel_decl(*this, KernelObj->getLocation(), | ||
KernelCallerFunc->isInlined(), IsSIMDKernel, | ||
KernelCallerFunc); | ||
SyclKernelBodyCreator kernel_body(*this, kernel_decl, KernelObj, | ||
KernelCallerFunc); | ||
SyclKernelIntHeaderCreator int_header( | ||
*this, getSyclIntegrationHeader(), KernelObj, | ||
KernelCallOperator, *this, getSyclIntegrationHeader(), KernelObj, | ||
srividya-sundaram marked this conversation as resolved.
Show resolved
Hide resolved
|
||
calculateKernelNameType(Context, KernelCallerFunc), KernelCallerFunc); | ||
|
||
SyclKernelIntFooterCreator int_footer(*this, getSyclIntegrationFooter()); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s | ||
// This test checks that the correct kernel operator call is invoked when there are multiple definitions of the | ||
// 'operator()()' call. | ||
|
||
#include "sycl.hpp" | ||
|
||
sycl::queue Q; | ||
|
||
// Check if functor with multiple call operator works. | ||
class Functor1 { | ||
public: | ||
Functor1(){} | ||
|
||
[[intel::reqd_sub_group_size(4)]] void operator()(sycl::id<1> id) const {} | ||
|
||
[[sycl::work_group_size_hint(1, 2, 3)]] void operator()(sycl::id<2> id) const {} | ||
|
||
}; | ||
|
||
// Check templated 'operator()()' call works. | ||
class kernels { | ||
public: | ||
kernels(){} | ||
|
||
template<int Dimensions = 1> | ||
[[sycl::work_group_size_hint(1, 2, 3)]] void operator()(sycl::id<Dimensions> item) const {} | ||
|
||
}; | ||
|
||
int main() { | ||
|
||
Q.submit([&](sycl::handler& cgh) { | ||
Functor1 F; | ||
// CHECK: define dso_local spir_kernel void @_ZTS8Functor1() #0 !srcloc !11 !kernel_arg_buffer_location !12 !intel_reqd_sub_group_size !13 !sycl_fixed_targets !12 { | ||
cgh.parallel_for(sycl::range<1>(10), F); | ||
}); | ||
|
||
Q.submit([&](sycl::handler& cgh) { | ||
kernels K; | ||
// CHECK: define dso_local spir_kernel void @_ZTS7kernels() #0 !srcloc !15 !kernel_arg_buffer_location !12 !work_group_size_hint !16 !sycl_fixed_targets !12 { | ||
cgh.parallel_for(sycl::range<1>(10), K); | ||
}); | ||
|
||
return 0; | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.