Skip to content

[SYCL] Refactor processing of parallel_for_work_group constructs #5918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 5, 2022
50 changes: 10 additions & 40 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -914,42 +914,6 @@ class KernelBodyTransform : public TreeTransform<KernelBodyTransform> {
Sema &SemaRef;
};

// Searches for a call to PFWG lambda function and captures it.
class FindPFWGLambdaFnVisitor
: public RecursiveASTVisitor<FindPFWGLambdaFnVisitor> {
public:
// LambdaObjTy - lambda type of the PFWG lambda object
FindPFWGLambdaFnVisitor(const CXXRecordDecl *LambdaObjTy)
: LambdaFn(nullptr), LambdaObjTy(LambdaObjTy) {}

bool VisitCallExpr(CallExpr *Call) {
auto *M = dyn_cast<CXXMethodDecl>(Call->getDirectCallee());
if (!M || (M->getOverloadedOperator() != OO_Call))
return true;

unsigned int NumPFWGLambdaArgs =
M->getNumParams() + 1; // group, optional kernel_handler and lambda obj
if (Call->getNumArgs() != NumPFWGLambdaArgs)
return true;
if (!Util::isSyclType(Call->getArg(1)->getType(), "group", true /*Tmpl*/))
return true;
if ((Call->getNumArgs() > 2) &&
!Util::isSyclKernelHandlerType(Call->getArg(2)->getType()))
return true;
if (Call->getArg(0)->getType()->getAsCXXRecordDecl() != LambdaObjTy)
return true;
LambdaFn = M; // call to PFWG lambda found - record the lambda
return false; // ... and stop searching
}

// Returns the captured lambda function or nullptr;
CXXMethodDecl *getLambdaFn() const { return LambdaFn; }

private:
CXXMethodDecl *LambdaFn;
const CXXRecordDecl *LambdaObjTy;
};

class MarkWIScopeFnVisitor : public RecursiveASTVisitor<MarkWIScopeFnVisitor> {
public:
MarkWIScopeFnVisitor(ASTContext &Ctx) : Ctx(Ctx) {}
Expand Down Expand Up @@ -2541,10 +2505,16 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
void markParallelWorkItemCalls() {
if (getKernelInvocationKind(KernelCallerFunc) ==
InvokeParallelForWorkGroup) {
FindPFWGLambdaFnVisitor V(KernelObj);
V.TraverseStmt(KernelCallerFunc->getBody());
CXXMethodDecl *WGLambdaFn = V.getLambdaFn();
assert(WGLambdaFn && "PFWG lambda not found");
// Fetch the kernel object and the associated call operator
// (of either the lambda or the function object).
CXXRecordDecl *KernelObj =
GetSYCLKernelObjectType(KernelCallerFunc)->getAsCXXRecordDecl();
CXXMethodDecl *WGLambdaFn = nullptr;
if (KernelObj->isLambda())
WGLambdaFn = KernelObj->getLambdaCallOperator();
else
WGLambdaFn = getOperatorParens(KernelObj);
assert(WGLambdaFn && "non callable object is passed as kernel obj");
// Mark the function that it "works" in a work group scope:
// NOTE: In case of parallel_for_work_item the marker call itself is
// marked with work item scope attribute, here the '()' operator of the
Expand Down
12 changes: 12 additions & 0 deletions clang/test/SemaSYCL/sycl-pfwg-invalid-code.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: %clang_cc1 -fsycl-is-device %s -verify

// Tests that the compiler does not crash (due to a triggered assertion)
// if definition of kernel_parallel_for_work_group is invalid.
template <typename, typename, typename K>
__attribute__((sycl_kernel)) void kernel_parallel_for_work_group(const K &) {
unknown(); // expected-error{{use of undeclared identifier 'unknown'}}
}
void foo() {
auto lambda = [] {};
kernel_parallel_for_work_group<int, int>(lambda);
}