Skip to content

Commit 496f9a0

Browse files
authored
[SYCL] Copy attributes of parallel_for kernels to the wrapped versions when rounding up the range (#3154)
Parallel_for range rounding creates a wrapped version of the original kernel. Kernel attributes need to be copied from original kernel to new wrapped kernel. Signed-off-by: rdeodhar <[email protected]>
1 parent 5af118a commit 496f9a0

18 files changed

+277
-120
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3192,6 +3192,9 @@ def warn_dllimport_dropped_from_inline_function : Warning<
31923192
InGroup<IgnoredAttributes>;
31933193
def warn_attribute_ignored : Warning<"%0 attribute ignored">,
31943194
InGroup<IgnoredAttributes>;
3195+
def warn_attribute_on_direct_kernel_callee_only : Warning<"%0 attribute allowed"
3196+
" only on a function directly called from a SYCL kernel function; attribute ignored">,
3197+
InGroup<IgnoredAttributes>;
31953198
def warn_nothrow_attribute_ignored : Warning<"'nothrow' attribute conflicts with"
31963199
" exception specification; attribute ignored">,
31973200
InGroup<IgnoredAttributes>;

clang/include/clang/Sema/Sema.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13012,6 +13012,7 @@ class Sema final {
1301213012

1301313013
bool isKnownGoodSYCLDecl(const Decl *D);
1301413014
void checkSYCLDeviceVarDecl(VarDecl *Var);
13015+
void copySYCLKernelAttrs(const CXXRecordDecl *KernelObj);
1301513016
void ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc, MangleContext &MC);
1301613017
void MarkDevice();
1301713018

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 100 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,36 @@ static int64_t getIntExprValue(const Expr *E, ASTContext &Ctx) {
306306
return E->getIntegerConstantExpr(Ctx)->getSExtValue();
307307
}
308308

309+
// Collect function attributes related to SYCL
310+
static void collectSYCLAttributes(Sema &S, FunctionDecl *FD,
311+
llvm::SmallVector<Attr *, 4> &Attrs,
312+
bool DirectlyCalled = true) {
313+
if (!FD->hasAttrs())
314+
return;
315+
316+
llvm::copy_if(FD->getAttrs(), std::back_inserter(Attrs), [](Attr *A) {
317+
return isa<IntelReqdSubGroupSizeAttr, ReqdWorkGroupSizeAttr,
318+
SYCLIntelKernelArgsRestrictAttr, SYCLIntelNumSimdWorkItemsAttr,
319+
SYCLIntelSchedulerTargetFmaxMhzAttr,
320+
SYCLIntelMaxWorkGroupSizeAttr, SYCLIntelMaxGlobalWorkDimAttr,
321+
SYCLIntelNoGlobalWorkOffsetAttr, SYCLSimdAttr>(A);
322+
});
323+
324+
// Allow the kernel attribute "use_stall_enable_clusters" only on lambda
325+
// functions and function objects called directly from a kernel.
326+
// For all other cases, emit a warning and ignore.
327+
if (auto *A = FD->getAttr<SYCLIntelUseStallEnableClustersAttr>()) {
328+
if (DirectlyCalled) {
329+
Attrs.push_back(A);
330+
} else {
331+
S.Diag(A->getLocation(),
332+
diag::warn_attribute_on_direct_kernel_callee_only)
333+
<< A;
334+
FD->dropAttr<SYCLIntelUseStallEnableClustersAttr>();
335+
}
336+
}
337+
}
338+
309339
class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
310340
// Used to keep track of the constexpr depth, so we know whether to skip
311341
// diagnostics.
@@ -477,7 +507,7 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
477507
// Returns the kernel body function found during traversal.
478508
FunctionDecl *
479509
CollectPossibleKernelAttributes(FunctionDecl *SYCLKernel,
480-
llvm::SmallPtrSet<Attr *, 4> &Attrs) {
510+
llvm::SmallVector<Attr *, 4> &Attrs) {
481511
typedef std::pair<FunctionDecl *, FunctionDecl *> ChildParentPair;
482512
llvm::SmallPtrSet<FunctionDecl *, 16> Visited;
483513
llvm::SmallVector<ChildParentPair, 16> WorkList;
@@ -508,55 +538,23 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
508538
"function can be called");
509539
KernelBody = FD;
510540
}
541+
511542
WorkList.pop_back();
512543
if (!Visited.insert(FD).second)
513544
continue; // We've already seen this Decl
514545

515-
if (auto *A = FD->getAttr<IntelReqdSubGroupSizeAttr>())
516-
Attrs.insert(A);
517-
518-
if (auto *A = FD->getAttr<ReqdWorkGroupSizeAttr>())
519-
Attrs.insert(A);
520-
521-
if (auto *A = FD->getAttr<SYCLIntelKernelArgsRestrictAttr>())
522-
Attrs.insert(A);
523-
524-
if (auto *A = FD->getAttr<SYCLIntelNumSimdWorkItemsAttr>())
525-
Attrs.insert(A);
526-
527-
if (auto *A = FD->getAttr<SYCLIntelSchedulerTargetFmaxMhzAttr>())
528-
Attrs.insert(A);
529-
530-
if (auto *A = FD->getAttr<SYCLIntelMaxWorkGroupSizeAttr>())
531-
Attrs.insert(A);
532-
533-
if (auto *A = FD->getAttr<SYCLIntelMaxGlobalWorkDimAttr>())
534-
Attrs.insert(A);
535-
536-
if (auto *A = FD->getAttr<SYCLIntelNoGlobalWorkOffsetAttr>())
537-
Attrs.insert(A);
538-
539-
if (auto *A = FD->getAttr<SYCLSimdAttr>())
540-
Attrs.insert(A);
541-
542-
// Allow the kernel attribute "use_stall_enable_clusters" only on lambda
543-
// functions and function objects that are called directly from a kernel
544-
// (i.e. the one passed to the single_task or parallel_for functions).
545-
// For all other cases, emit a warning and ignore.
546-
if (auto *A = FD->getAttr<SYCLIntelUseStallEnableClustersAttr>()) {
547-
if (ParentFD == SYCLKernel) {
548-
Attrs.insert(A);
549-
} else {
550-
SemaRef.Diag(A->getLocation(), diag::warn_attribute_ignored) << A;
551-
FD->dropAttr<SYCLIntelUseStallEnableClustersAttr>();
552-
}
553-
}
546+
// Gather all attributes of FD that are SYCL related.
547+
// Some attributes are allowed only on lambda functions and function
548+
// objects called directly from a kernel (i.e. the one passed to the
549+
// single_task or parallel_for functions).
550+
bool DirectlyCalled = (ParentFD == SYCLKernel);
551+
collectSYCLAttributes(SemaRef, FD, Attrs, DirectlyCalled);
554552

555553
// Attribute "loop_fuse" can be applied explicitly on kernel function.
556554
// Attribute should not be propagated from device functions to kernel.
557555
if (auto *A = FD->getAttr<SYCLIntelLoopFuseAttr>()) {
558556
if (ParentFD == SYCLKernel) {
559-
Attrs.insert(A);
557+
Attrs.push_back(A);
560558
}
561559
}
562560

@@ -2058,8 +2056,8 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler {
20582056
using SyclKernelFieldHandler::handleSyclHalfType;
20592057
};
20602058

2061-
static const CXXMethodDecl *getOperatorParens(const CXXRecordDecl *Rec) {
2062-
for (const auto *MD : Rec->methods()) {
2059+
static CXXMethodDecl *getOperatorParens(const CXXRecordDecl *Rec) {
2060+
for (auto *MD : Rec->methods()) {
20632061
if (MD->getOverloadedOperator() == OO_Call)
20642062
return MD;
20652063
}
@@ -3149,6 +3147,56 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
31493147
KernelFunc->setInvalidDecl();
31503148
}
31513149

3150+
// For a wrapped parallel_for, copy attributes from original
3151+
// kernel to wrapped kernel.
3152+
void Sema::copySYCLKernelAttrs(const CXXRecordDecl *KernelObj) {
3153+
// Get the operator() function of the wrapper
3154+
CXXMethodDecl *OpParens = getOperatorParens(KernelObj);
3155+
assert(OpParens && "invalid kernel object");
3156+
3157+
typedef std::pair<FunctionDecl *, FunctionDecl *> ChildParentPair;
3158+
llvm::SmallPtrSet<FunctionDecl *, 16> Visited;
3159+
llvm::SmallVector<ChildParentPair, 16> WorkList;
3160+
WorkList.push_back({OpParens, nullptr});
3161+
FunctionDecl *KernelBody = nullptr;
3162+
3163+
CallGraph SYCLCG;
3164+
SYCLCG.addToCallGraph(getASTContext().getTranslationUnitDecl());
3165+
while (!WorkList.empty()) {
3166+
FunctionDecl *FD = WorkList.back().first;
3167+
FunctionDecl *ParentFD = WorkList.back().second;
3168+
3169+
if ((ParentFD == OpParens) && isSYCLKernelBodyFunction(FD)) {
3170+
KernelBody = FD;
3171+
break;
3172+
}
3173+
3174+
WorkList.pop_back();
3175+
if (!Visited.insert(FD).second)
3176+
continue; // We've already seen this Decl
3177+
3178+
CallGraphNode *N = SYCLCG.getNode(FD);
3179+
if (!N)
3180+
continue;
3181+
3182+
for (const CallGraphNode *CI : *N) {
3183+
if (auto *Callee = dyn_cast<FunctionDecl>(CI->getDecl())) {
3184+
Callee = Callee->getMostRecentDecl();
3185+
if (!Visited.count(Callee))
3186+
WorkList.push_back({Callee, FD});
3187+
}
3188+
}
3189+
}
3190+
3191+
assert(KernelBody && "improper parallel_for wrap");
3192+
if (KernelBody) {
3193+
llvm::SmallVector<Attr *, 4> Attrs;
3194+
collectSYCLAttributes(*this, KernelBody, Attrs);
3195+
if (!Attrs.empty())
3196+
llvm::for_each(Attrs, [OpParens](Attr *A) { OpParens->addAttr(A); });
3197+
}
3198+
}
3199+
31523200
// Generates the OpenCL kernel using KernelCallerFunc (kernel caller
31533201
// function) defined is SYCL headers.
31543202
// Generated OpenCL kernel contains the body of the kernel caller function,
@@ -3181,14 +3229,20 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
31813229
if (KernelObj->isInvalidDecl())
31823230
return;
31833231

3184-
bool IsSIMDKernel = isESIMDKernelType(KernelObj);
3185-
31863232
// Calculate both names, since Integration headers need both.
31873233
std::string CalculatedName, StableName;
31883234
std::tie(CalculatedName, StableName) =
31893235
constructKernelName(*this, KernelCallerFunc, MC);
31903236
StringRef KernelName(getLangOpts().SYCLUnnamedLambda ? StableName
31913237
: CalculatedName);
3238+
3239+
// Attributes of a user-written SYCL kernel must be copied to the internally
3240+
// generated alternative kernel, identified by a known string in its name.
3241+
if (StableName.find("__pf_kernel_wrapper") != std::string::npos)
3242+
copySYCLKernelAttrs(KernelObj);
3243+
3244+
bool IsSIMDKernel = isESIMDKernelType(KernelObj);
3245+
31923246
SyclKernelDeclCreator kernel_decl(*this, KernelName, KernelObj->getLocation(),
31933247
KernelCallerFunc->isInlined(),
31943248
IsSIMDKernel);
@@ -3226,7 +3280,7 @@ void Sema::MarkDevice(void) {
32263280
Marker.CollectKernelSet(SYCLKernel, SYCLKernel, VisitedSet);
32273281

32283282
// Let's propagate attributes from device functions to a SYCL kernels
3229-
llvm::SmallPtrSet<Attr *, 4> Attrs;
3283+
llvm::SmallVector<Attr *, 4> Attrs;
32303284
// This function collects all kernel attributes which might be applied to
32313285
// a device functions, but need to be propagated down to callers, i.e.
32323286
// SYCL kernels

clang/test/SemaSYCL/Inputs/sycl.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,24 @@ template <typename Type>
206206
struct get_kernel_name_t<auto_name, Type> {
207207
using name = Type;
208208
};
209+
210+
// Used when parallel_for range is rounded-up.
211+
template <typename Type> class __pf_kernel_wrapper;
212+
213+
template <typename Type> struct get_kernel_wrapper_name_t {
214+
using name =
215+
__pf_kernel_wrapper<typename get_kernel_name_t<auto_name, Type>::name>;
216+
};
217+
209218
#define ATTR_SYCL_KERNEL __attribute__((sycl_kernel))
210219
template <typename KernelName = auto_name, typename KernelType>
211220
ATTR_SYCL_KERNEL void kernel_single_task(const KernelType &kernelFunc) {
212221
kernelFunc();
213222
}
223+
template <typename KernelName = auto_name, typename KernelType>
224+
ATTR_SYCL_KERNEL void kernel_parallel_for(const KernelType &kernelFunc) {
225+
kernelFunc();
226+
}
214227
class handler {
215228
public:
216229
template <typename KernelName = auto_name, typename KernelType>
@@ -220,6 +233,16 @@ class handler {
220233
kernel_single_task<NameT>(kernelFunc);
221234
#else
222235
kernelFunc();
236+
#endif
237+
}
238+
template <typename KernelName = auto_name, typename KernelType>
239+
void parallel_for(const KernelType &kernelObj) {
240+
using NameT = typename get_kernel_name_t<KernelName, KernelType>::name;
241+
using NameWT = typename get_kernel_wrapper_name_t<NameT>::name;
242+
#ifdef __SYCL_DEVICE_ONLY__
243+
kernel_parallel_for<NameT>(kernelObj);
244+
#else
245+
kernelObj();
223246
#endif
224247
}
225248
};

clang/test/SemaSYCL/args-size-overflow.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ queue q;
1111
using Accessor =
1212
accessor<int, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::global_buffer>;
1313
#ifdef SPIR64
14-
// expected-warning@Inputs/sycl.hpp:220 {{size of kernel arguments (7994 bytes) may exceed the supported maximum of 2048 bytes on some devices}}
14+
// expected-warning@Inputs/sycl.hpp:233 {{size of kernel arguments (7994 bytes) may exceed the supported maximum of 2048 bytes on some devices}}
1515
#elif SPIR32
16-
// expected-warning@Inputs/sycl.hpp:220 {{size of kernel arguments (7986 bytes) may exceed the supported maximum of 2048 bytes on some devices}}
16+
// expected-warning@Inputs/sycl.hpp:233 {{size of kernel arguments (7986 bytes) may exceed the supported maximum of 2048 bytes on some devices}}
1717
#endif
1818

1919
void use() {

clang/test/SemaSYCL/deferred-diagnostics-aux-builtin.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ int main(int argc, char **argv) {
1212
_mm_prefetch("test", 8); // expected-error {{argument value 8 is outside the valid range [0, 7]}}
1313

1414
deviceQueue.submit([&](sycl::handler &h) {
15-
// expected-note@Inputs/sycl.hpp:212 {{called by 'kernel_single_task<AName, (lambda}}
15+
// expected-note@Inputs/sycl.hpp:221 {{called by 'kernel_single_task<AName, (lambda}}
1616
h.single_task<class AName>([]() {
1717
_mm_prefetch("test", 4); // expected-error {{builtin is not supported on this target}}
1818
_mm_prefetch("test", 8); // expected-error {{argument value 8 is outside the valid range [0, 7]}} expected-error {{builtin is not supported on this target}}

clang/test/SemaSYCL/deferred-diagnostics-emit.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ template <typename T>
6464
void setup_sycl_operation(const T VA[]) {
6565

6666
deviceQueue.submit([&](sycl::handler &h) {
67-
// expected-note@Inputs/sycl.hpp:212 {{called by 'kernel_single_task<AName, (lambda}}
67+
// expected-note@Inputs/sycl.hpp:221 {{called by 'kernel_single_task<AName, (lambda}}
6868
h.single_task<class AName>([]() {
6969
// ======= Zero Length Arrays Not Allowed in Kernel ==========
7070
// expected-error@+1 {{zero-length arrays are not permitted in C++}}
@@ -156,7 +156,7 @@ int main(int argc, char **argv) {
156156

157157
// --- direct lambda testing ---
158158
deviceQueue.submit([&](sycl::handler &h) {
159-
// expected-note@Inputs/sycl.hpp:212 2 {{called by 'kernel_single_task<AName, (lambda}}
159+
// expected-note@Inputs/sycl.hpp:221 2 {{called by 'kernel_single_task<AName, (lambda}}
160160
h.single_task<class AName>([]() {
161161
// expected-error@+1 {{zero-length arrays are not permitted in C++}}
162162
int BadArray[0];

clang/test/SemaSYCL/float128.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ int main() {
7171
__float128 CapturedToDevice = 1;
7272
host_ok();
7373
deviceQueue.submit([&](sycl::handler &h) {
74-
// expected-note@Inputs/sycl.hpp:212 {{called by 'kernel_single_task<variables, (lambda}}
74+
// expected-note@Inputs/sycl.hpp:221 {{called by 'kernel_single_task<variables, (lambda}}
7575
h.single_task<class variables>([=]() {
7676
// expected-error@+1 {{'__float128' is not supported on this target}}
7777
decltype(CapturedToDevice) D;
@@ -88,7 +88,7 @@ int main() {
8888
});
8989

9090
deviceQueue.submit([&](sycl::handler &h) {
91-
// expected-note@Inputs/sycl.hpp:212 4{{called by 'kernel_single_task<functions, (lambda}}
91+
// expected-note@Inputs/sycl.hpp:221 4{{called by 'kernel_single_task<functions, (lambda}}
9292
h.single_task<class functions>([=]() {
9393
// expected-note@+1 2{{called by 'operator()'}}
9494
usage();
@@ -104,7 +104,7 @@ int main() {
104104
});
105105

106106
deviceQueue.submit([&](sycl::handler &h) {
107-
// expected-note@Inputs/sycl.hpp:212 {{called by 'kernel_single_task<ok, (lambda}}
107+
// expected-note@Inputs/sycl.hpp:221 {{called by 'kernel_single_task<ok, (lambda}}
108108
h.single_task<class ok>([=]() {
109109
// expected-note@+1 3{{used here}}
110110
Z<__float128> S;

clang/test/SemaSYCL/implicit_kernel_type.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ int main() {
2525
queue q;
2626

2727
#if defined(WARN)
28-
// expected-error@Inputs/sycl.hpp:220 {{'InvalidKernelName1' is an invalid kernel name type}}
29-
// expected-note@Inputs/sycl.hpp:220 {{'InvalidKernelName1' should be globally-visible}}
28+
// expected-error@Inputs/sycl.hpp:233 {{'InvalidKernelName1' is an invalid kernel name type}}
29+
// expected-note@Inputs/sycl.hpp:233 {{'InvalidKernelName1' should be globally-visible}}
3030
// expected-note@+8 {{in instantiation of function template specialization}}
3131
#elif defined(ERROR)
32-
// expected-error@Inputs/sycl.hpp:220 {{'InvalidKernelName1' is an invalid kernel name type}}
33-
// expected-note@Inputs/sycl.hpp:220 {{'InvalidKernelName1' should be globally-visible}}
32+
// expected-error@Inputs/sycl.hpp:233 {{'InvalidKernelName1' is an invalid kernel name type}}
33+
// expected-note@Inputs/sycl.hpp:233 {{'InvalidKernelName1' should be globally-visible}}
3434
// expected-note@+4 {{in instantiation of function template specialization}}
3535
#endif
3636
class InvalidKernelName1 {};
@@ -39,9 +39,9 @@ int main() {
3939
});
4040

4141
#if defined(WARN)
42-
// expected-warning@Inputs/sycl.hpp:220 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
42+
// expected-warning@Inputs/sycl.hpp:233 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
4343
#elif defined(ERROR)
44-
// expected-error@Inputs/sycl.hpp:220 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
44+
// expected-error@Inputs/sycl.hpp:233 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
4545
#endif
4646

4747
q.submit([&](handler &h) {
@@ -53,9 +53,9 @@ int main() {
5353
});
5454

5555
#if defined(WARN)
56-
// expected-warning@Inputs/sycl.hpp:220 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
56+
// expected-warning@Inputs/sycl.hpp:233 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
5757
#elif defined(ERROR)
58-
// expected-error@Inputs/sycl.hpp:220 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
58+
// expected-error@Inputs/sycl.hpp:233 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
5959
#endif
6060

6161
q.submit([&](handler &h) {

0 commit comments

Comments
 (0)