@@ -99,10 +99,23 @@ class Util {
99
99
// / \param Tmpl whether the class is template instantiation or simple record
100
100
static bool isSyclType (const QualType &Ty, StringRef Name, bool Tmpl = false );
101
101
102
+ // / Checks whether given function is a standard SYCL API function with given
103
+ // / name.
104
+ // / \param FD the function being checked.
105
+ // / \param Name the function name to be checked against.
106
+ static bool isSyclFunction (const FunctionDecl *FD, StringRef Name);
107
+
102
108
// / Checks whether given clang type is a full specialization of the SYCL
103
109
// / specialization constant class.
104
110
static bool isSyclSpecConstantType (const QualType &Ty);
105
111
112
+ // Checks declaration context hierarchy.
113
+ // / \param DC the context of the item to be checked.
114
+ // / \param Scopes the declaration scopes leading from the item context to the
115
+ // / translation unit (excluding the latter)
116
+ static bool matchContext (const DeclContext *DC,
117
+ ArrayRef<Util::DeclContextDesc> Scopes);
118
+
106
119
// / Checks whether given clang type is declared in the given hierarchy of
107
120
// / declaration contexts.
108
121
// / \param Ty the clang type being checked
@@ -165,38 +178,14 @@ static bool IsSyclMathFunc(unsigned BuiltinID) {
165
178
case Builtin::BI__builtin_truncl:
166
179
case Builtin::BIlroundl:
167
180
case Builtin::BI__builtin_lroundl:
168
- case Builtin::BIcopysign:
169
- case Builtin::BI__builtin_copysign:
170
- case Builtin::BIfloor:
171
- case Builtin::BI__builtin_floor:
172
181
case Builtin::BIfmax:
173
182
case Builtin::BI__builtin_fmax:
174
183
case Builtin::BIfmin:
175
184
case Builtin::BI__builtin_fmin:
176
- case Builtin::BInearbyint:
177
- case Builtin::BI__builtin_nearbyint:
178
- case Builtin::BIrint:
179
- case Builtin::BI__builtin_rint:
180
- case Builtin::BIround:
181
- case Builtin::BI__builtin_round:
182
- case Builtin::BItrunc:
183
- case Builtin::BI__builtin_trunc:
184
- case Builtin::BIcopysignf:
185
- case Builtin::BI__builtin_copysignf:
186
- case Builtin::BIfloorf:
187
- case Builtin::BI__builtin_floorf:
188
185
case Builtin::BIfmaxf:
189
186
case Builtin::BI__builtin_fmaxf:
190
187
case Builtin::BIfminf:
191
188
case Builtin::BI__builtin_fminf:
192
- case Builtin::BInearbyintf:
193
- case Builtin::BI__builtin_nearbyintf:
194
- case Builtin::BIrintf:
195
- case Builtin::BI__builtin_rintf:
196
- case Builtin::BIroundf:
197
- case Builtin::BI__builtin_roundf:
198
- case Builtin::BItruncf:
199
- case Builtin::BI__builtin_truncf:
200
189
case Builtin::BIlroundf:
201
190
case Builtin::BI__builtin_lroundf:
202
191
case Builtin::BI__builtin_fpclassify:
@@ -511,6 +500,21 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
511
500
FunctionDecl *FD = WorkList.back ().first ;
512
501
FunctionDecl *ParentFD = WorkList.back ().second ;
513
502
503
+ // To implement rounding-up of a parallel-for range the
504
+ // SYCL header implementation modifies the kernel call like this:
505
+ // auto Wrapper = [=](TransformedArgType Arg) {
506
+ // if (Arg[0] >= NumWorkItems[0])
507
+ // return;
508
+ // Arg.set_allowed_range(NumWorkItems);
509
+ // KernelFunc(Arg);
510
+ // };
511
+ //
512
+ // This transformation leads to a condition where a kernel body
513
+ // function becomes callable from a new kernel body function.
514
+ // Hence this test.
515
+ if ((ParentFD == KernelBody) && isSYCLKernelBodyFunction (FD))
516
+ KernelBody = FD;
517
+
514
518
if ((ParentFD == SYCLKernel) && isSYCLKernelBodyFunction (FD)) {
515
519
assert (!KernelBody && " inconsistent call graph - only one kernel body "
516
520
" function can be called" );
@@ -2691,15 +2695,63 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
2691
2695
return !SemaRef.getASTContext ().hasSameType (FD->getType (), Ty);
2692
2696
}
2693
2697
2698
+ // Sets a flag if the kernel is a parallel_for that calls the
2699
+ // free function API "this_item".
2700
+ void setThisItemIsCalled (const CXXRecordDecl *KernelObj,
2701
+ FunctionDecl *KernelFunc) {
2702
+ if (getKernelInvocationKind (KernelFunc) != InvokeParallelFor)
2703
+ return ;
2704
+
2705
+ const CXXMethodDecl *WGLambdaFn = getOperatorParens (KernelObj);
2706
+ if (!WGLambdaFn)
2707
+ return ;
2708
+
2709
+ // The call graph for this translation unit.
2710
+ CallGraph SYCLCG;
2711
+ SYCLCG.addToCallGraph (SemaRef.getASTContext ().getTranslationUnitDecl ());
2712
+ using ChildParentPair =
2713
+ std::pair<const FunctionDecl *, const FunctionDecl *>;
2714
+ llvm::SmallPtrSet<const FunctionDecl *, 16 > Visited;
2715
+ llvm::SmallVector<ChildParentPair, 16 > WorkList;
2716
+ WorkList.push_back ({WGLambdaFn, nullptr });
2717
+
2718
+ while (!WorkList.empty ()) {
2719
+ const FunctionDecl *FD = WorkList.back ().first ;
2720
+ WorkList.pop_back ();
2721
+ if (!Visited.insert (FD).second )
2722
+ continue ; // We've already seen this Decl
2723
+
2724
+ // Check whether this call is to sycl::this_item().
2725
+ if (Util::isSyclFunction (FD, " this_item" )) {
2726
+ Header.setCallsThisItem (true );
2727
+ return ;
2728
+ }
2729
+
2730
+ CallGraphNode *N = SYCLCG.getNode (FD);
2731
+ if (!N)
2732
+ continue ;
2733
+
2734
+ for (const CallGraphNode *CI : *N) {
2735
+ if (auto *Callee = dyn_cast<FunctionDecl>(CI->getDecl ())) {
2736
+ Callee = Callee->getMostRecentDecl ();
2737
+ if (!Visited.count (Callee))
2738
+ WorkList.push_back ({Callee, FD});
2739
+ }
2740
+ }
2741
+ }
2742
+ }
2743
+
2694
2744
public:
2695
2745
static constexpr const bool VisitInsideSimpleContainers = false ;
2696
2746
SyclKernelIntHeaderCreator (Sema &S, SYCLIntegrationHeader &H,
2697
2747
const CXXRecordDecl *KernelObj, QualType NameType,
2698
- StringRef Name, StringRef StableName)
2748
+ StringRef Name, StringRef StableName,
2749
+ FunctionDecl *KernelFunc)
2699
2750
: SyclKernelFieldHandler(S), Header(H) {
2700
2751
bool IsSIMDKernel = isESIMDKernelType (KernelObj);
2701
2752
Header.startKernel (Name, NameType, StableName, KernelObj->getLocation (),
2702
2753
IsSIMDKernel);
2754
+ setThisItemIsCalled (KernelObj, KernelFunc);
2703
2755
}
2704
2756
2705
2757
bool handleSyclAccessorType (const CXXRecordDecl *RD,
@@ -3147,7 +3199,7 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
3147
3199
SyclKernelIntHeaderCreator int_header (
3148
3200
*this , getSyclIntegrationHeader (), KernelObj,
3149
3201
calculateKernelNameType (Context, KernelCallerFunc), KernelName,
3150
- StableName);
3202
+ StableName, KernelCallerFunc );
3151
3203
3152
3204
KernelObjVisitor Visitor{*this };
3153
3205
Visitor.VisitRecordBases (KernelObj, kernel_decl, kernel_body, int_header);
@@ -3866,6 +3918,9 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
3866
3918
O << " __SYCL_DLL_LOCAL\n " ;
3867
3919
O << " static constexpr bool isESIMD() { return " << K.IsESIMDKernel
3868
3920
<< " ; }\n " ;
3921
+ O << " __SYCL_DLL_LOCAL\n " ;
3922
+ O << " static constexpr bool callsThisItem() { return " ;
3923
+ O << K.CallsThisItem << " ; }\n " ;
3869
3924
O << " };\n " ;
3870
3925
CurStart += N;
3871
3926
}
@@ -3924,6 +3979,12 @@ void SYCLIntegrationHeader::addSpecConstant(StringRef IDName, QualType IDType) {
3924
3979
SpecConsts.emplace_back (std::make_pair (IDType, IDName.str ()));
3925
3980
}
3926
3981
3982
+ void SYCLIntegrationHeader::setCallsThisItem (bool B) {
3983
+ KernelDesc *K = getCurKernelDesc ();
3984
+ assert (K && " no kernels" );
3985
+ K->CallsThisItem = B;
3986
+ }
3987
+
3927
3988
SYCLIntegrationHeader::SYCLIntegrationHeader (DiagnosticsEngine &_Diag,
3928
3989
bool _UnnamedLambdaSupport,
3929
3990
Sema &_S)
@@ -3991,6 +4052,21 @@ bool Util::isSyclType(const QualType &Ty, StringRef Name, bool Tmpl) {
3991
4052
return matchQualifiedTypeName (Ty, Scopes);
3992
4053
}
3993
4054
4055
+ bool Util::isSyclFunction (const FunctionDecl *FD, StringRef Name) {
4056
+ if (!FD->isFunctionOrMethod () || !FD->getIdentifier () ||
4057
+ FD->getName ().empty () || Name != FD->getName ())
4058
+ return false ;
4059
+
4060
+ const DeclContext *DC = FD->getDeclContext ();
4061
+ if (DC->isTranslationUnit ())
4062
+ return false ;
4063
+
4064
+ std::array<DeclContextDesc, 2 > Scopes = {
4065
+ Util::DeclContextDesc{clang::Decl::Kind::Namespace, " cl" },
4066
+ Util::DeclContextDesc{clang::Decl::Kind::Namespace, " sycl" }};
4067
+ return matchContext (DC, Scopes);
4068
+ }
4069
+
3994
4070
bool Util::isAccessorPropertyListType (const QualType &Ty) {
3995
4071
const StringRef &Name = " accessor_property_list" ;
3996
4072
std::array<DeclContextDesc, 4 > Scopes = {
@@ -4001,21 +4077,15 @@ bool Util::isAccessorPropertyListType(const QualType &Ty) {
4001
4077
return matchQualifiedTypeName (Ty, Scopes);
4002
4078
}
4003
4079
4004
- bool Util::matchQualifiedTypeName (const QualType &Ty ,
4005
- ArrayRef<Util::DeclContextDesc> Scopes) {
4006
- // The idea: check the declaration context chain starting from the type
4080
+ bool Util::matchContext (const DeclContext *Ctx ,
4081
+ ArrayRef<Util::DeclContextDesc> Scopes) {
4082
+ // The idea: check the declaration context chain starting from the item
4007
4083
// itself. At each step check the context is of expected kind
4008
4084
// (namespace) and name.
4009
- const CXXRecordDecl *RecTy = Ty->getAsCXXRecordDecl ();
4010
-
4011
- if (!RecTy)
4012
- return false ; // only classes/structs supported
4013
- const auto *Ctx = cast<DeclContext>(RecTy);
4014
4085
StringRef Name = " " ;
4015
4086
4016
4087
for (const auto &Scope : llvm::reverse (Scopes)) {
4017
4088
clang::Decl::Kind DK = Ctx->getDeclKind ();
4018
-
4019
4089
if (DK != Scope.first )
4020
4090
return false ;
4021
4091
@@ -4029,11 +4099,21 @@ bool Util::matchQualifiedTypeName(const QualType &Ty,
4029
4099
Name = cast<NamespaceDecl>(Ctx)->getName ();
4030
4100
break ;
4031
4101
default :
4032
- llvm_unreachable (" matchQualifiedTypeName : decl kind not supported" );
4102
+ llvm_unreachable (" matchContext : decl kind not supported" );
4033
4103
}
4034
4104
if (Name != Scope.second )
4035
4105
return false ;
4036
4106
Ctx = Ctx->getParent ();
4037
4107
}
4038
4108
return Ctx->isTranslationUnit ();
4039
4109
}
4110
+
4111
+ bool Util::matchQualifiedTypeName (const QualType &Ty,
4112
+ ArrayRef<Util::DeclContextDesc> Scopes) {
4113
+ const CXXRecordDecl *RecTy = Ty->getAsCXXRecordDecl ();
4114
+
4115
+ if (!RecTy)
4116
+ return false ; // only classes/structs supported
4117
+ const auto *Ctx = cast<DeclContext>(RecTy);
4118
+ return Util::matchContext (Ctx, Scopes);
4119
+ }
0 commit comments