@@ -56,7 +56,7 @@ enum KernelInvocationKind {
56
56
57
57
const static std::string InitMethodName = " __init" ;
58
58
const static std::string FinalizeMethodName = " __finalize" ;
59
- constexpr unsigned GPUMaxKernelArgsNum = 2000 ;
59
+ constexpr unsigned GPUMaxKernelArgsSize = 2048 ;
60
60
61
61
namespace {
62
62
@@ -1656,32 +1656,35 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
1656
1656
using SyclKernelFieldHandler::leaveStruct;
1657
1657
};
1658
1658
1659
- class SyclKernelNumArgsChecker : public SyclKernelFieldHandler {
1659
+ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler {
1660
1660
SourceLocation KernelLoc;
1661
- unsigned NumOfParams = 0 ;
1661
+ unsigned SizeOfParams = 0 ;
1662
+
1663
+ void addParam (QualType ArgTy) {
1664
+ SizeOfParams +=
1665
+ SemaRef.getASTContext ().getTypeSizeInChars (ArgTy).getQuantity ();
1666
+ }
1662
1667
1663
1668
bool handleSpecialType (QualType FieldTy) {
1664
1669
const CXXRecordDecl *RecordDecl = FieldTy->getAsCXXRecordDecl ();
1665
1670
assert (RecordDecl && " The accessor/sampler must be a RecordDecl" );
1666
1671
CXXMethodDecl *InitMethod = getMethodByName (RecordDecl, InitMethodName);
1667
1672
assert (InitMethod && " The accessor/sampler must have the __init method" );
1668
- NumOfParams += InitMethod->getNumParams ();
1673
+ for (const ParmVarDecl *Param : InitMethod->parameters ())
1674
+ addParam (Param->getType ());
1669
1675
return true ;
1670
1676
}
1671
1677
1672
1678
public:
1673
- SyclKernelNumArgsChecker (Sema &S, SourceLocation Loc)
1679
+ SyclKernelArgsSizeChecker (Sema &S, SourceLocation Loc)
1674
1680
: SyclKernelFieldHandler(S), KernelLoc(Loc) {}
1675
1681
1676
- ~SyclKernelNumArgsChecker () {
1682
+ ~SyclKernelArgsSizeChecker () {
1677
1683
if (SemaRef.Context .getTargetInfo ().getTriple ().getSubArch () ==
1678
- llvm::Triple::SPIRSubArch_gen) {
1679
- if (NumOfParams > GPUMaxKernelArgsNum) {
1680
- SemaRef.Diag (KernelLoc, diag::warn_sycl_kernel_too_many_args)
1681
- << NumOfParams << GPUMaxKernelArgsNum;
1682
- SemaRef.Diag (KernelLoc, diag::note_sycl_kernel_args_count);
1683
- }
1684
- }
1684
+ llvm::Triple::SPIRSubArch_gen)
1685
+ if (SizeOfParams > GPUMaxKernelArgsSize)
1686
+ SemaRef.Diag (KernelLoc, diag::warn_sycl_kernel_too_big_args)
1687
+ << SizeOfParams << GPUMaxKernelArgsSize;
1685
1688
}
1686
1689
1687
1690
bool handleSyclAccessorType (FieldDecl *FD, QualType FieldTy) final {
@@ -1703,12 +1706,12 @@ class SyclKernelNumArgsChecker : public SyclKernelFieldHandler {
1703
1706
}
1704
1707
1705
1708
bool handlePointerType (FieldDecl *FD, QualType FieldTy) final {
1706
- NumOfParams++ ;
1709
+ addParam (FieldTy) ;
1707
1710
return true ;
1708
1711
}
1709
1712
1710
1713
bool handleScalarType (FieldDecl *FD, QualType FieldTy) final {
1711
- NumOfParams++ ;
1714
+ addParam (FieldTy) ;
1712
1715
return true ;
1713
1716
}
1714
1717
@@ -1717,17 +1720,17 @@ class SyclKernelNumArgsChecker : public SyclKernelFieldHandler {
1717
1720
}
1718
1721
1719
1722
bool handleSyclHalfType (FieldDecl *FD, QualType FieldTy) final {
1720
- NumOfParams++ ;
1723
+ addParam (FieldTy) ;
1721
1724
return true ;
1722
1725
}
1723
1726
1724
1727
bool handleSyclStreamType (FieldDecl *FD, QualType FieldTy) final {
1725
- NumOfParams++ ;
1728
+ addParam (FieldTy) ;
1726
1729
return true ;
1727
1730
}
1728
1731
bool handleSyclStreamType (const CXXRecordDecl *, const CXXBaseSpecifier &,
1729
1732
QualType FieldTy) final {
1730
- NumOfParams++ ;
1733
+ addParam (FieldTy) ;
1731
1734
return true ;
1732
1735
}
1733
1736
using SyclKernelFieldHandler::handleSyclHalfType;
@@ -2468,7 +2471,7 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
2468
2471
2469
2472
SyclKernelFieldChecker FieldChecker (*this );
2470
2473
SyclKernelUnionChecker UnionChecker (*this );
2471
- SyclKernelNumArgsChecker NumArgsChecker (*this , Args[0 ]->getExprLoc ());
2474
+ SyclKernelArgsSizeChecker ArgsSizeChecker (*this , Args[0 ]->getExprLoc ());
2472
2475
// check that calling kernel conforms to spec
2473
2476
QualType KernelParamTy = KernelFunc->getParamDecl (0 )->getType ();
2474
2477
if (KernelParamTy->isReferenceType ()) {
@@ -2488,9 +2491,9 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
2488
2491
KernelObjVisitor Visitor{*this };
2489
2492
DiagnosingSYCLKernel = true ;
2490
2493
Visitor.VisitRecordBases (KernelObj, FieldChecker, UnionChecker,
2491
- NumArgsChecker );
2494
+ ArgsSizeChecker );
2492
2495
Visitor.VisitRecordFields (KernelObj, FieldChecker, UnionChecker,
2493
- NumArgsChecker );
2496
+ ArgsSizeChecker );
2494
2497
DiagnosingSYCLKernel = false ;
2495
2498
if (!FieldChecker.isValid () || !UnionChecker.isValid ())
2496
2499
KernelFunc->setInvalidDecl ();
0 commit comments