@@ -56,6 +56,7 @@ enum KernelInvocationKind {
56
56
57
57
const static std::string InitMethodName = " __init" ;
58
58
const static std::string FinalizeMethodName = " __finalize" ;
59
+ constexpr unsigned GPUMaxKernelArgsNum = 2000 ;
59
60
60
61
namespace {
61
62
@@ -1657,6 +1658,83 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
1657
1658
using SyclKernelFieldHandler::leaveStruct;
1658
1659
};
1659
1660
1661
+ class SyclKernelNumArgsChecker : public SyclKernelFieldHandler {
1662
+ SourceLocation KernelLoc;
1663
+ unsigned NumOfParams = 0 ;
1664
+
1665
+ bool handleSpecialType (QualType FieldTy) {
1666
+ const CXXRecordDecl *RecordDecl = FieldTy->getAsCXXRecordDecl ();
1667
+ assert (RecordDecl && " The accessor/sampler must be a RecordDecl" );
1668
+ CXXMethodDecl *InitMethod = getMethodByName (RecordDecl, InitMethodName);
1669
+ assert (InitMethod && " The accessor/sampler must have the __init method" );
1670
+ NumOfParams += InitMethod->getNumParams ();
1671
+ return true ;
1672
+ }
1673
+
1674
+ public:
1675
+ SyclKernelNumArgsChecker (Sema &S, SourceLocation Loc)
1676
+ : SyclKernelFieldHandler(S), KernelLoc(Loc) {}
1677
+
1678
+ ~SyclKernelNumArgsChecker () {
1679
+ if (SemaRef.Context .getTargetInfo ().getTriple ().getSubArch () ==
1680
+ llvm::Triple::SPIRSubArch_gen) {
1681
+ if (NumOfParams > GPUMaxKernelArgsNum) {
1682
+ SemaRef.Diag (KernelLoc, diag::warn_sycl_kernel_too_many_args)
1683
+ << NumOfParams << GPUMaxKernelArgsNum;
1684
+ SemaRef.Diag (KernelLoc, diag::note_sycl_kernel_args_count);
1685
+ }
1686
+ }
1687
+ }
1688
+
1689
+ bool handleSyclAccessorType (FieldDecl *FD, QualType FieldTy) final {
1690
+ return handleSpecialType (FieldTy);
1691
+ }
1692
+
1693
+ bool handleSyclAccessorType (const CXXRecordDecl *, const CXXBaseSpecifier &,
1694
+ QualType FieldTy) final {
1695
+ return handleSpecialType (FieldTy);
1696
+ }
1697
+
1698
+ bool handleSyclSamplerType (FieldDecl *FD, QualType FieldTy) final {
1699
+ return handleSpecialType (FieldTy);
1700
+ }
1701
+
1702
+ bool handleSyclSamplerType (const CXXRecordDecl *, const CXXBaseSpecifier &BS,
1703
+ QualType FieldTy) final {
1704
+ return handleSpecialType (FieldTy);
1705
+ }
1706
+
1707
+ bool handlePointerType (FieldDecl *FD, QualType FieldTy) final {
1708
+ NumOfParams++;
1709
+ return true ;
1710
+ }
1711
+
1712
+ bool handleScalarType (FieldDecl *FD, QualType FieldTy) final {
1713
+ NumOfParams++;
1714
+ return true ;
1715
+ }
1716
+
1717
+ bool handleUnionType (FieldDecl *FD, QualType FieldTy) final {
1718
+ return handleScalarType (FD, FieldTy);
1719
+ }
1720
+
1721
+ bool handleSyclHalfType (FieldDecl *FD, QualType FieldTy) final {
1722
+ NumOfParams++;
1723
+ return true ;
1724
+ }
1725
+
1726
+ bool handleSyclStreamType (FieldDecl *FD, QualType FieldTy) final {
1727
+ NumOfParams++;
1728
+ return true ;
1729
+ }
1730
+ bool handleSyclStreamType (const CXXRecordDecl *, const CXXBaseSpecifier &,
1731
+ QualType FieldTy) final {
1732
+ NumOfParams++;
1733
+ return true ;
1734
+ }
1735
+ using SyclKernelFieldHandler::handleSyclHalfType;
1736
+ };
1737
+
1660
1738
class SyclKernelBodyCreator : public SyclKernelFieldHandler {
1661
1739
SyclKernelDeclCreator &DeclCreator;
1662
1740
llvm::SmallVector<Stmt *, 16 > BodyStmts;
@@ -2351,6 +2429,7 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
2351
2429
2352
2430
SyclKernelFieldChecker FieldChecker (*this );
2353
2431
SyclKernelUnionChecker UnionChecker (*this );
2432
+ SyclKernelNumArgsChecker NumArgsChecker (*this , Args[0 ]->getExprLoc ());
2354
2433
// check that calling kernel conforms to spec
2355
2434
QualType KernelParamTy = KernelFunc->getParamDecl (0 )->getType ();
2356
2435
if (KernelParamTy->isReferenceType ()) {
@@ -2365,8 +2444,10 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
2365
2444
2366
2445
KernelObjVisitor Visitor{*this };
2367
2446
DiagnosingSYCLKernel = true ;
2368
- Visitor.VisitRecordBases (KernelObj, FieldChecker, UnionChecker);
2369
- Visitor.VisitRecordFields (KernelObj, FieldChecker, UnionChecker);
2447
+ Visitor.VisitRecordBases (KernelObj, FieldChecker, UnionChecker,
2448
+ NumArgsChecker);
2449
+ Visitor.VisitRecordFields (KernelObj, FieldChecker, UnionChecker,
2450
+ NumArgsChecker);
2370
2451
DiagnosingSYCLKernel = false ;
2371
2452
if (!FieldChecker.isValid () || !UnionChecker.isValid ())
2372
2453
KernelFunc->setInvalidDecl ();
0 commit comments