@@ -914,42 +914,6 @@ class KernelBodyTransform : public TreeTransform<KernelBodyTransform> {
914
914
Sema &SemaRef;
915
915
};
916
916
917
- // Searches for a call to PFWG lambda function and captures it.
918
- class FindPFWGLambdaFnVisitor
919
- : public RecursiveASTVisitor<FindPFWGLambdaFnVisitor> {
920
- public:
921
- // LambdaObjTy - lambda type of the PFWG lambda object
922
- FindPFWGLambdaFnVisitor (const CXXRecordDecl *LambdaObjTy)
923
- : LambdaFn(nullptr ), LambdaObjTy(LambdaObjTy) {}
924
-
925
- bool VisitCallExpr (CallExpr *Call) {
926
- auto *M = dyn_cast<CXXMethodDecl>(Call->getDirectCallee ());
927
- if (!M || (M->getOverloadedOperator () != OO_Call))
928
- return true ;
929
-
930
- unsigned int NumPFWGLambdaArgs =
931
- M->getNumParams () + 1 ; // group, optional kernel_handler and lambda obj
932
- if (Call->getNumArgs () != NumPFWGLambdaArgs)
933
- return true ;
934
- if (!Util::isSyclType (Call->getArg (1 )->getType (), " group" , true /* Tmpl*/ ))
935
- return true ;
936
- if ((Call->getNumArgs () > 2 ) &&
937
- !Util::isSyclKernelHandlerType (Call->getArg (2 )->getType ()))
938
- return true ;
939
- if (Call->getArg (0 )->getType ()->getAsCXXRecordDecl () != LambdaObjTy)
940
- return true ;
941
- LambdaFn = M; // call to PFWG lambda found - record the lambda
942
- return false ; // ... and stop searching
943
- }
944
-
945
- // Returns the captured lambda function or nullptr;
946
- CXXMethodDecl *getLambdaFn () const { return LambdaFn; }
947
-
948
- private:
949
- CXXMethodDecl *LambdaFn;
950
- const CXXRecordDecl *LambdaObjTy;
951
- };
952
-
953
917
class MarkWIScopeFnVisitor : public RecursiveASTVisitor <MarkWIScopeFnVisitor> {
954
918
public:
955
919
MarkWIScopeFnVisitor (ASTContext &Ctx) : Ctx(Ctx) {}
@@ -2541,10 +2505,16 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
2541
2505
void markParallelWorkItemCalls () {
2542
2506
if (getKernelInvocationKind (KernelCallerFunc) ==
2543
2507
InvokeParallelForWorkGroup) {
2544
- FindPFWGLambdaFnVisitor V (KernelObj);
2545
- V.TraverseStmt (KernelCallerFunc->getBody ());
2546
- CXXMethodDecl *WGLambdaFn = V.getLambdaFn ();
2547
- assert (WGLambdaFn && " PFWG lambda not found" );
2508
+ // Fetch the kernel object and the associated call operator
2509
+ // (of either the lambda or the function object).
2510
+ CXXRecordDecl *KernelObj =
2511
+ GetSYCLKernelObjectType (KernelCallerFunc)->getAsCXXRecordDecl ();
2512
+ CXXMethodDecl *WGLambdaFn = nullptr ;
2513
+ if (KernelObj->isLambda ())
2514
+ WGLambdaFn = KernelObj->getLambdaCallOperator ();
2515
+ else
2516
+ WGLambdaFn = getOperatorParens (KernelObj);
2517
+ assert (WGLambdaFn && " non callable object is passed as kernel obj" );
2548
2518
// Mark the function that it "works" in a work group scope:
2549
2519
// NOTE: In case of parallel_for_work_item the marker call itself is
2550
2520
// marked with work item scope attribute, here the '()' operator of the
0 commit comments