Skip to content

Commit 490ee55

Browse files
[SYCL] Fix crash caused by functor without call operator as Kernel (#7104)
When a named function object without a call operator defined is used as a kernel, the compiler causes a `crash`. With the changes in this PR, a compiler `error` is thrown when a lambda or function object without a call operator defined is used as a kernel. Co-authored-by: premanandrao <[email protected]>
1 parent c49eeda commit 490ee55

File tree

2 files changed

+67
-16
lines changed

2 files changed

+67
-16
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,28 @@ static QualType GetSYCLKernelObjectType(const FunctionDecl *KernelCaller) {
988988
return KernelParamTy.getUnqualifiedType();
989989
}
990990

991+
static CXXMethodDecl *getOperatorParens(const CXXRecordDecl *Rec) {
992+
for (auto *MD : Rec->methods()) {
993+
if (MD->getOverloadedOperator() == OO_Call)
994+
return MD;
995+
}
996+
return nullptr;
997+
}
998+
999+
// Fetch the associated call operator of the kernel object
1000+
// (of either the lambda or the function object).
1001+
static CXXMethodDecl *
1002+
GetCallOperatorOfKernelObject(const CXXRecordDecl *KernelObjType) {
1003+
CXXMethodDecl *CallOperator = nullptr;
1004+
if (!KernelObjType)
1005+
return CallOperator;
1006+
if (KernelObjType->isLambda())
1007+
CallOperator = KernelObjType->getLambdaCallOperator();
1008+
else
1009+
CallOperator = getOperatorParens(KernelObjType);
1010+
return CallOperator;
1011+
}
1012+
9911013
/// Creates a kernel parameter descriptor
9921014
/// \param Src field declaration to construct name from
9931015
/// \param Ty the desired parameter type
@@ -2775,14 +2797,6 @@ class SyclOptReportCreator : public SyclKernelFieldHandler {
27752797
}
27762798
};
27772799

2778-
static CXXMethodDecl *getOperatorParens(const CXXRecordDecl *Rec) {
2779-
for (auto *MD : Rec->methods()) {
2780-
if (MD->getOverloadedOperator() == OO_Call)
2781-
return MD;
2782-
}
2783-
return nullptr;
2784-
}
2785-
27862800
static bool isESIMDKernelType(const CXXRecordDecl *KernelObjType) {
27872801
const CXXMethodDecl *OpParens = getOperatorParens(KernelObjType);
27882802
return (OpParens != nullptr) && OpParens->hasAttr<SYCLSimdAttr>();
@@ -2871,13 +2885,10 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
28712885

28722886
// Fetch the kernel object and the associated call operator
28732887
// (of either the lambda or the function object).
2874-
CXXRecordDecl *KernelObj =
2888+
const CXXRecordDecl *KernelObj =
28752889
GetSYCLKernelObjectType(KernelCallerFunc)->getAsCXXRecordDecl();
2876-
CXXMethodDecl *WGLambdaFn = nullptr;
2877-
if (KernelObj->isLambda())
2878-
WGLambdaFn = KernelObj->getLambdaCallOperator();
2879-
else
2880-
WGLambdaFn = getOperatorParens(KernelObj);
2890+
CXXMethodDecl *WGLambdaFn = GetCallOperatorOfKernelObject(KernelObj);
2891+
28812892
assert(WGLambdaFn && "non callable object is passed as kernel obj");
28822893
// Mark the function that it "works" in a work group scope:
28832894
// NOTE: In case of parallel_for_work_item the marker call itself is
@@ -3534,7 +3545,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
35343545
static bool IsSYCLUnnamedKernel(Sema &SemaRef, const FunctionDecl *FD) {
35353546
if (!SemaRef.getLangOpts().SYCLUnnamedLambda)
35363547
return false;
3537-
QualType FunctorTy = GetSYCLKernelObjectType(FD);
3548+
const QualType FunctorTy = GetSYCLKernelObjectType(FD);
35383549
QualType TmplArgTy = calculateKernelNameType(SemaRef.Context, FD);
35393550
return SemaRef.Context.hasSameType(FunctorTy, TmplArgTy);
35403551
}
@@ -3960,7 +3971,7 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc,
39603971
const CXXRecordDecl *KernelObj =
39613972
GetSYCLKernelObjectType(KernelFunc)->getAsCXXRecordDecl();
39623973

3963-
if (!KernelObj) {
3974+
if (!GetCallOperatorOfKernelObject(KernelObj)) {
39643975
Diag(Args[0]->getExprLoc(), diag::err_sycl_kernel_not_function_object);
39653976
KernelFunc->setInvalidDecl();
39663977
return;
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -sycl-std=2020 -verify -fsyntax-only %s
2+
// This test checks that an error is thrown when a functor without a call operator defined is used as a kernel.
3+
4+
#include "sycl.hpp"
5+
6+
using namespace sycl;
7+
queue q;
8+
9+
struct FunctorWithoutCallOperator; // expected-note {{forward declaration of 'FunctorWithoutCallOperator'}}
10+
11+
struct StructDefined {
12+
int x;
13+
};
14+
15+
class FunctorWithCallOpDefined {
16+
int x;
17+
public:
18+
void operator()() const {}
19+
};
20+
21+
int main() {
22+
23+
q.submit([&](sycl::handler &cgh) {
24+
// expected-error@#KernelSingleTask {{kernel parameter must be a lambda or function object}}
25+
// expected-error@+2 {{invalid use of incomplete type 'FunctorWithoutCallOperator'}}
26+
// expected-note@+1 {{in instantiation of function template specialization}}
27+
cgh.single_task(FunctorWithoutCallOperator{});
28+
});
29+
30+
q.submit([&](sycl::handler &cgh) {
31+
// expected-error@#KernelSingleTask {{kernel parameter must be a lambda or function object}}
32+
// expected-note@+1 {{in instantiation of function template specialization}}
33+
cgh.single_task(StructDefined{});
34+
});
35+
36+
q.submit([&](sycl::handler &cgh) {
37+
cgh.single_task(FunctorWithCallOpDefined{});
38+
});
39+
40+
}

0 commit comments

Comments
 (0)