Skip to content

Commit 11f53ad

Browse files
[SYCL] Add support for templated call operator in functors (#7970)
This PR enables kernels to be defined as named function object types, where the operator() member function is templated. Example: ``` class FunctorWithCallOpTemplated { int x; public: template <int x = 0> void operator()() const {} }; q.submit([&](sycl::handler &cgh) { cgh.single_task(FunctorWithCallOpTemplated{}); }); ``` --------- Co-authored-by: premanandrao <[email protected]>
1 parent e85b51f commit 11f53ad

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -991,12 +991,26 @@ static QualType GetSYCLKernelObjectType(const FunctionDecl *KernelCaller) {
991991
return KernelParamTy.getUnqualifiedType();
992992
}
993993

994-
static CXXMethodDecl *getOperatorParens(const CXXRecordDecl *Rec) {
995-
for (auto *MD : Rec->methods()) {
996-
if (MD->getOverloadedOperator() == OO_Call)
997-
return MD;
998-
}
999-
return nullptr;
994+
// Get the call operator function associated with the function object
995+
// for both templated and non-templated operator()().
996+
997+
static CXXMethodDecl *getFunctorCallOperator(const CXXRecordDecl *RD) {
998+
DeclarationName Name =
999+
RD->getASTContext().DeclarationNames.getCXXOperatorName(OO_Call);
1000+
DeclContext::lookup_result Calls = RD->lookup(Name);
1001+
1002+
if (Calls.empty())
1003+
return nullptr;
1004+
1005+
NamedDecl *CallOp = Calls.front();
1006+
1007+
if (CallOp == nullptr)
1008+
return nullptr;
1009+
1010+
if (const auto *CallOpTmpl = dyn_cast<FunctionTemplateDecl>(CallOp))
1011+
return cast<CXXMethodDecl>(CallOpTmpl->getTemplatedDecl());
1012+
1013+
return cast<CXXMethodDecl>(CallOp);
10001014
}
10011015

10021016
// Fetch the associated call operator of the kernel object
@@ -1009,7 +1023,7 @@ GetCallOperatorOfKernelObject(const CXXRecordDecl *KernelObjType) {
10091023
if (KernelObjType->isLambda())
10101024
CallOperator = KernelObjType->getLambdaCallOperator();
10111025
else
1012-
CallOperator = getOperatorParens(KernelObjType);
1026+
CallOperator = getFunctorCallOperator(KernelObjType);
10131027
return CallOperator;
10141028
}
10151029

@@ -2802,7 +2816,7 @@ class SyclOptReportCreator : public SyclKernelFieldHandler {
28022816
};
28032817

28042818
static bool isESIMDKernelType(const CXXRecordDecl *KernelObjType) {
2805-
const CXXMethodDecl *OpParens = getOperatorParens(KernelObjType);
2819+
const CXXMethodDecl *OpParens = getFunctorCallOperator(KernelObjType);
28062820
return (OpParens != nullptr) && OpParens->hasAttr<SYCLSimdAttr>();
28072821
}
28082822

@@ -4041,7 +4055,7 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc,
40414055
// kernel to wrapped kernel.
40424056
void Sema::copySYCLKernelAttrs(const CXXRecordDecl *KernelObj) {
40434057
// Get the operator() function of the wrapper.
4044-
CXXMethodDecl *OpParens = getOperatorParens(KernelObj);
4058+
CXXMethodDecl *OpParens = getFunctorCallOperator(KernelObj);
40454059
assert(OpParens && "invalid kernel object");
40464060

40474061
typedef std::pair<FunctionDecl *, FunctionDecl *> ChildParentPair;

clang/test/SemaSYCL/undefined-functor.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ class FunctorWithCallOpDefined {
1818
void operator()() const {}
1919
};
2020

21+
class FunctorWithCallOpTemplated {
22+
public:
23+
template <int x = 0>
24+
void operator()() const {}
25+
};
26+
2127
int main() {
2228

2329
q.submit([&](sycl::handler &cgh) {
@@ -37,4 +43,8 @@ int main() {
3743
cgh.single_task(FunctorWithCallOpDefined{});
3844
});
3945

46+
q.submit([&](sycl::handler &cgh) {
47+
cgh.single_task(FunctorWithCallOpTemplated{});
48+
});
49+
4050
}

0 commit comments

Comments
 (0)