@@ -988,6 +988,28 @@ static QualType GetSYCLKernelObjectType(const FunctionDecl *KernelCaller) {
988
988
return KernelParamTy.getUnqualifiedType ();
989
989
}
990
990
991
+ static CXXMethodDecl *getOperatorParens (const CXXRecordDecl *Rec) {
992
+ for (auto *MD : Rec->methods ()) {
993
+ if (MD->getOverloadedOperator () == OO_Call)
994
+ return MD;
995
+ }
996
+ return nullptr ;
997
+ }
998
+
999
+ // Fetch the associated call operator of the kernel object
1000
+ // (of either the lambda or the function object).
1001
+ static CXXMethodDecl *
1002
+ GetCallOperatorOfKernelObject (const CXXRecordDecl *KernelObjType) {
1003
+ CXXMethodDecl *CallOperator = nullptr ;
1004
+ if (!KernelObjType)
1005
+ return CallOperator;
1006
+ if (KernelObjType->isLambda ())
1007
+ CallOperator = KernelObjType->getLambdaCallOperator ();
1008
+ else
1009
+ CallOperator = getOperatorParens (KernelObjType);
1010
+ return CallOperator;
1011
+ }
1012
+
991
1013
// / Creates a kernel parameter descriptor
992
1014
// / \param Src field declaration to construct name from
993
1015
// / \param Ty the desired parameter type
@@ -2775,14 +2797,6 @@ class SyclOptReportCreator : public SyclKernelFieldHandler {
2775
2797
}
2776
2798
};
2777
2799
2778
- static CXXMethodDecl *getOperatorParens (const CXXRecordDecl *Rec) {
2779
- for (auto *MD : Rec->methods ()) {
2780
- if (MD->getOverloadedOperator () == OO_Call)
2781
- return MD;
2782
- }
2783
- return nullptr ;
2784
- }
2785
-
2786
2800
static bool isESIMDKernelType (const CXXRecordDecl *KernelObjType) {
2787
2801
const CXXMethodDecl *OpParens = getOperatorParens (KernelObjType);
2788
2802
return (OpParens != nullptr ) && OpParens->hasAttr <SYCLSimdAttr>();
@@ -2871,13 +2885,10 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
2871
2885
2872
2886
// Fetch the kernel object and the associated call operator
2873
2887
// (of either the lambda or the function object).
2874
- CXXRecordDecl *KernelObj =
2888
+ const CXXRecordDecl *KernelObj =
2875
2889
GetSYCLKernelObjectType (KernelCallerFunc)->getAsCXXRecordDecl ();
2876
- CXXMethodDecl *WGLambdaFn = nullptr ;
2877
- if (KernelObj->isLambda ())
2878
- WGLambdaFn = KernelObj->getLambdaCallOperator ();
2879
- else
2880
- WGLambdaFn = getOperatorParens (KernelObj);
2890
+ CXXMethodDecl *WGLambdaFn = GetCallOperatorOfKernelObject (KernelObj);
2891
+
2881
2892
assert (WGLambdaFn && " non callable object is passed as kernel obj" );
2882
2893
// Mark the function that it "works" in a work group scope:
2883
2894
// NOTE: In case of parallel_for_work_item the marker call itself is
@@ -3534,7 +3545,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
3534
3545
static bool IsSYCLUnnamedKernel (Sema &SemaRef, const FunctionDecl *FD) {
3535
3546
if (!SemaRef.getLangOpts ().SYCLUnnamedLambda )
3536
3547
return false ;
3537
- QualType FunctorTy = GetSYCLKernelObjectType (FD);
3548
+ const QualType FunctorTy = GetSYCLKernelObjectType (FD);
3538
3549
QualType TmplArgTy = calculateKernelNameType (SemaRef.Context , FD);
3539
3550
return SemaRef.Context .hasSameType (FunctorTy, TmplArgTy);
3540
3551
}
@@ -3960,7 +3971,7 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc,
3960
3971
const CXXRecordDecl *KernelObj =
3961
3972
GetSYCLKernelObjectType (KernelFunc)->getAsCXXRecordDecl ();
3962
3973
3963
- if (!KernelObj) {
3974
+ if (!GetCallOperatorOfKernelObject ( KernelObj) ) {
3964
3975
Diag (Args[0 ]->getExprLoc (), diag::err_sycl_kernel_not_function_object);
3965
3976
KernelFunc->setInvalidDecl ();
3966
3977
return ;
0 commit comments