Skip to content

Commit 51c2fd8

Browse files
committed
[SYCL] Add __set_range call for accessor class
Added creation of new kernel argument Added __set_range call after __set_pointer call Signed-off-by: Vladimir Lazarev <[email protected]>
1 parent dd5ff0a commit 51c2fd8

File tree

5 files changed

+167
-45
lines changed

5 files changed

+167
-45
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 157 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
#include "clang/AST/RecordLayout.h"
1515
#include "clang/AST/RecursiveASTVisitor.h"
1616
#include "clang/Sema/Sema.h"
17-
#include "llvm/ADT/SmallVector.h"
1817
#include "llvm/ADT/SmallPtrSet.h"
18+
#include "llvm/ADT/SmallVector.h"
1919
#include "llvm/Support/FileSystem.h"
2020
#include "llvm/Support/Path.h"
2121
#include "llvm/Support/raw_ostream.h"
@@ -154,14 +154,28 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
154154
S.Context, NestedNameSpecifierLoc(), SourceLocation(), LambdaVD, false,
155155
DeclarationNameInfo(), QualType(LC->getTypeForDecl(), 0), VK_LValue);
156156

157-
// Initialize Lambda fields
158-
llvm::SmallVector<Expr *, 16> InitCaptures;
159-
160157
auto TargetFunc = dyn_cast<FunctionDecl>(DC);
161158
auto TargetFuncParam =
162159
TargetFunc->param_begin(); // Iterator to ParamVarDecl (VarDecl)
163160
if (TargetFuncParam) {
164161
for (auto Field : LC->fields()) {
162+
auto getExprForPointer = [](Sema &S, const QualType &paramTy,
163+
DeclRefExpr *DRE) {
164+
// C++ address space attribute != OpenCL address space attribute
165+
Expr *qualifiersCast = ImplicitCastExpr::Create(
166+
S.Context, paramTy, CK_NoOp, DRE, nullptr, VK_LValue);
167+
Expr *Res =
168+
ImplicitCastExpr::Create(S.Context, paramTy, CK_LValueToRValue,
169+
qualifiersCast, nullptr, VK_RValue);
170+
return Res;
171+
};
172+
auto getExprForRange = [](Sema &S, const QualType &paramTy,
173+
DeclRefExpr *DRE) {
174+
Expr *Res = ImplicitCastExpr::Create(S.Context, paramTy, CK_NoOp, DRE,
175+
nullptr, VK_RValue);
176+
return Res;
177+
};
178+
165179
QualType ParamType = (*TargetFuncParam)->getOriginalType();
166180
auto DRE =
167181
DeclRefExpr::Create(S.Context, NestedNameSpecifierLoc(),
@@ -171,18 +185,20 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
171185
QualType FieldType = Field->getType();
172186
CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl();
173187
if (CRD) {
174-
llvm::SmallVector<Expr *, 16> ParamStmts;
175188
DeclAccessPair FieldDAP = DeclAccessPair::make(Field, AS_none);
189+
// lambda.accessor
176190
auto AccessorME = MemberExpr::Create(
177191
S.Context, LambdaDRE, false, SourceLocation(),
178192
NestedNameSpecifierLoc(), SourceLocation(), Field, FieldDAP,
179193
DeclarationNameInfo(Field->getDeclName(), SourceLocation()),
180194
nullptr, Field->getType(), VK_LValue, OK_Ordinary);
181-
195+
bool PointerOfAccesorWasSet = false;
182196
for (auto Method : CRD->methods()) {
197+
llvm::SmallVector<Expr *, 16> ParamStmts;
183198
if (Method->getNameInfo().getName().getAsString() ==
184199
"__set_pointer") {
185200
DeclAccessPair MethodDAP = DeclAccessPair::make(Method, AS_none);
201+
// lambda.accessor.__set_pointer
186202
auto ME = MemberExpr::Create(
187203
S.Context, AccessorME, false, SourceLocation(),
188204
NestedNameSpecifierLoc(), SourceLocation(), Method, MethodDAP,
@@ -199,19 +215,75 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
199215
// __set_pointer needs one parameter
200216
QualType paramTy = (*(Method->param_begin()))->getOriginalType();
201217

202-
// C++ address space attribute != OpenCL address space attribute
203-
Expr *qualifiersCast = ImplicitCastExpr::Create(
204-
S.Context, paramTy, CK_NoOp, DRE, nullptr, VK_LValue);
205-
Expr *Res = ImplicitCastExpr::Create(
206-
S.Context, paramTy, CK_LValueToRValue, qualifiersCast,
207-
nullptr, VK_RValue);
218+
Expr *Res = getExprForPointer(S, paramTy, DRE);
208219

220+
// kernel_parameter
209221
ParamStmts.push_back(Res);
210-
211222
// lambda.accessor.__set_pointer(kernel_parameter)
212223
CXXMemberCallExpr *Call = CXXMemberCallExpr::Create(
213224
S.Context, ME, ParamStmts, ResultTy, VK, SourceLocation());
214225
BodyStmts.push_back(Call);
226+
PointerOfAccesorWasSet = true;
227+
}
228+
}
229+
if (PointerOfAccesorWasSet) {
230+
TargetFuncParam++;
231+
232+
ParamType = (*TargetFuncParam)->getOriginalType();
233+
DRE = DeclRefExpr::Create(S.Context, NestedNameSpecifierLoc(),
234+
SourceLocation(), *TargetFuncParam, false,
235+
DeclarationNameInfo(), ParamType,
236+
VK_LValue);
237+
238+
FieldType = Field->getType();
239+
CRD = FieldType->getAsCXXRecordDecl();
240+
if (CRD) {
241+
FieldDAP = DeclAccessPair::make(Field, AS_none);
242+
// lambda.accessor
243+
AccessorME = MemberExpr::Create(
244+
S.Context, LambdaDRE, false, SourceLocation(),
245+
NestedNameSpecifierLoc(), SourceLocation(), Field, FieldDAP,
246+
DeclarationNameInfo(Field->getDeclName(), SourceLocation()),
247+
nullptr, Field->getType(), VK_LValue, OK_Ordinary);
248+
249+
for (auto Method : CRD->methods()) {
250+
llvm::SmallVector<Expr *, 16> ParamStmts;
251+
if (Method->getNameInfo().getName().getAsString() ==
252+
"__set_range") {
253+
// lambda.accessor.__set_range
254+
DeclAccessPair MethodDAP =
255+
DeclAccessPair::make(Method, AS_none);
256+
auto ME = MemberExpr::Create(
257+
S.Context, AccessorME, false, SourceLocation(),
258+
NestedNameSpecifierLoc(), SourceLocation(), Method,
259+
MethodDAP, Method->getNameInfo(), nullptr,
260+
Method->getType(), VK_LValue, OK_Ordinary);
261+
262+
// Not referenced -> not emitted
263+
S.MarkFunctionReferenced(SourceLocation(), Method, true);
264+
265+
QualType ResultTy = Method->getReturnType();
266+
ExprValueKind VK = Expr::getValueKindForType(ResultTy);
267+
ResultTy = ResultTy.getNonLValueExprType(S.Context);
268+
269+
// __set_range needs one parameter
270+
QualType paramTy =
271+
(*(Method->param_begin()))->getOriginalType();
272+
273+
Expr *Res = getExprForRange(S, paramTy, DRE);
274+
275+
// kernel_parameter
276+
ParamStmts.push_back(Res);
277+
// lambda.accessor.__set_range(kernel_parameter)
278+
CXXMemberCallExpr *Call = CXXMemberCallExpr::Create(
279+
S.Context, ME, ParamStmts, ResultTy, VK,
280+
SourceLocation());
281+
BodyStmts.push_back(Call);
282+
}
283+
}
284+
} else {
285+
llvm_unreachable(
286+
"unsupported accessor and without initialized range");
215287
}
216288
}
217289
} else if (FieldType->isBuiltinType()) {
@@ -279,6 +351,7 @@ class Util {
279351
/// invocation.
280352
enum VisitorContext {
281353
pre_visit,
354+
pre_visit_class_field,
282355
visit_accessor,
283356
visit_scalar,
284357
visit_stream,
@@ -308,7 +381,7 @@ static void visitKernelLambdaCaptures(const CXXRecordDecl *Lambda,
308381
QualType ArgTy = V->getType();
309382
auto F1 = std::get<pre_visit>(Vis);
310383
F1(Cnt, V, *Fld);
311-
384+
FieldDecl *AccessorRangeField = nullptr;
312385
if (Util::isSyclAccessorType(ArgTy)) {
313386
// accessor parameter context
314387
const auto *RecordDecl = ArgTy->getAsCXXRecordDecl();
@@ -317,6 +390,26 @@ static void visitKernelLambdaCaptures(const CXXRecordDecl *Lambda,
317390
dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl);
318391
assert(TemplateDecl && "templated accessor type expected");
319392

393+
auto getFieldByName = [](const CXXRecordDecl *RecordDecl,
394+
std::string Name) {
395+
FieldDecl *result = nullptr;
396+
for (auto jt = RecordDecl->field_begin(); jt != RecordDecl->field_end();
397+
++jt) {
398+
if (jt->getNameAsString() == Name) {
399+
result = *jt;
400+
break;
401+
}
402+
}
403+
return result;
404+
};
405+
FieldDecl *AccessorImplField = getFieldByName(RecordDecl, "__impl");
406+
assert(AccessorImplField && "no __impl found in accessor");
407+
const auto *AccessorImplRecord =
408+
AccessorImplField->getType()->getAsCXXRecordDecl();
409+
assert(AccessorImplRecord && "accessor __impl must be of a record type");
410+
AccessorRangeField = getFieldByName(AccessorImplRecord, "Range");
411+
assert(AccessorRangeField && "no Range found in __impl of accessor");
412+
320413
// First accessor template parameter - data type
321414
QualType PointeeType = TemplateDecl->getTemplateArgs()[0].getAsType();
322415
// Fourth parameter - access target
@@ -335,9 +428,20 @@ static void visitKernelLambdaCaptures(const CXXRecordDecl *Lambda,
335428
} else {
336429
llvm_unreachable("unsupported kernel parameter type");
337430
}
338-
// pos-visit context
431+
// post-visit context
339432
auto F2 = std::get<post_visit>(Vis);
340433
F2(Cnt, V, *Fld);
434+
435+
if (AccessorRangeField) {
436+
// pre-visit context the same like for accessor
437+
auto F1Range = std::get<pre_visit_class_field>(Vis);
438+
F1Range(Cnt, V, *Fld, AccessorRangeField);
439+
auto FRange = std::get<visit_scalar>(Vis);
440+
FRange(Cnt, V, AccessorRangeField);
441+
// post-visit context
442+
auto F2Range = std::get<post_visit>(Vis);
443+
F2Range(Cnt, nullptr, AccessorRangeField);
444+
}
341445
}
342446
assert((Cpt == CptEnd) && (Fld == FldEnd) &&
343447
"captures inconsistent with fields");
@@ -350,6 +454,8 @@ static void BuildArgTys(ASTContext &Context, CXXRecordDecl *Lambda,
350454
auto Vis = std::make_tuple(
351455
// pre_visit
352456
[&](int, VarDecl *, FieldDecl *) {},
457+
// pre_visit_class_field
458+
[&](int, VarDecl *, FieldDecl *, FieldDecl *) {},
353459
// visit_accessor
354460
[&](int CaptureN, target AccTrg, QualType PointeeType,
355461
DeclaratorDecl *CapturedVar, FieldDecl *CapturedVal) {
@@ -390,7 +496,10 @@ static void BuildArgTys(ASTContext &Context, CXXRecordDecl *Lambda,
390496
IdentifierInfo *VarName = 0;
391497
SmallString<8> Str;
392498
llvm::raw_svector_ostream OS(Str);
393-
OS << "_arg_" << CapturedVar->getIdentifier()->getName();
499+
IdentifierInfo *Identifier = (CapturedVar != nullptr)
500+
? CapturedVar->getIdentifier()
501+
: CapturedVal->getIdentifier();
502+
OS << "_arg_" << Identifier->getName();
394503
VarName = &Context.Idents.get(OS.str());
395504

396505
auto NewVarDecl = VarDecl::Create(
@@ -422,7 +531,24 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
422531
[&](int CaptureN, VarDecl *CapturedVar, FieldDecl *CapturedVal) {
423532
// Set offset in bytes
424533
Offset = static_cast<unsigned>(
425-
Layout.getFieldOffset(CapturedVal->getFieldIndex()))/8;
534+
Layout.getFieldOffset(CapturedVal->getFieldIndex())) /
535+
8;
536+
},
537+
// pre_visit_class_field
538+
[&](int CaptureN, VarDecl *CapturedVar, FieldDecl *CapturedVal,
539+
FieldDecl *MemberVal) {
540+
// Set offset of parent in bytes
541+
Offset = static_cast<unsigned>(
542+
Layout.getFieldOffset(CapturedVal->getFieldIndex())) /
543+
8;
544+
const RecordDecl *parent = MemberVal->getParent();
545+
ASTContext &CtxMember = parent->getASTContext();
546+
const ASTRecordLayout &LayoutMember =
547+
CtxMember.getASTRecordLayout(parent);
548+
// Add offset relative to parent in bytes
549+
Offset += static_cast<unsigned>(
550+
LayoutMember.getFieldOffset(MemberVal->getFieldIndex())) /
551+
8;
426552
},
427553
// visit_accessor
428554
[&](int CaptureN, target AccTrg, QualType PointeeType,
@@ -453,15 +579,12 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
453579
// are removed to make the name shorter. Non-alphanumeric characters in a kernel
454580
// name are OK - SPIRV and runtimes allow that.
455581
static std::string constructKernelName(QualType KernelNameType) {
456-
static const std::string Kwds[] = {
457-
std::string("class"),
458-
std::string("struct")
459-
};
582+
static const std::string Kwds[] = {std::string("class"),
583+
std::string("struct")};
460584
std::string TStr = KernelNameType.getAsString();
461585

462586
for (const std::string &Kwd : Kwds) {
463-
for (size_t Pos = TStr.find(Kwd);
464-
Pos != StringRef::npos;
587+
for (size_t Pos = TStr.find(Kwd); Pos != StringRef::npos;
465588
Pos = TStr.find(Kwd, Pos)) {
466589

467590
size_t EndPos = Pos + Kwd.length();
@@ -593,12 +716,13 @@ static void printDecl(raw_ostream &O, const Decl *D) {
593716
// \param Depth
594717
// recursion depth
595718
//
596-
static void emitForwardClassDecls(raw_ostream &O,
597-
QualType T,
598-
llvm::SmallPtrSetImpl<const void*> &Printed) {
719+
static void
720+
emitForwardClassDecls(raw_ostream &O, QualType T,
721+
llvm::SmallPtrSetImpl<const void *> &Printed) {
599722

600723
// peel off the pointer types and get the class/struct type:
601-
for (; T->isPointerType(); T = T->getPointeeType());
724+
for (; T->isPointerType(); T = T->getPointeeType())
725+
;
602726
const CXXRecordDecl *RD = T->getAsCXXRecordDecl();
603727

604728
if (!RD)
@@ -657,7 +781,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
657781
O << "\n";
658782
O << "// Forward declarations of templated kernel function types:\n";
659783

660-
llvm::SmallPtrSet<const void*, 4> Printed;
784+
llvm::SmallPtrSet<const void *, 4> Printed;
661785

662786
for (const KernelDesc &K : KernelDescs) {
663787
emitForwardClassDecls(O, K.NameType, Printed);
@@ -737,12 +861,11 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
737861

738862
for (const KernelDesc &K : KernelDescs) {
739863
const size_t N = K.Params.size();
740-
O << "template <> struct KernelInfo<" <<
741-
K.NameType.getAsString() << "> {\n";
742-
O << " static constexpr const char* getName() { return \""
743-
<< K.Name << "\"; }\n";
744-
O << " static constexpr unsigned getNumParams() { return "
745-
<< N << "; }\n";
864+
O << "template <> struct KernelInfo<" << K.NameType.getAsString()
865+
<< "> {\n";
866+
O << " static constexpr const char* getName() { return \"" << K.Name
867+
<< "\"; }\n";
868+
O << " static constexpr unsigned getNumParams() { return " << N << "; }\n";
746869
O << " static constexpr const kernel_param_desc_t& ";
747870
O << "getParamDesc(unsigned i) {\n";
748871
O << " return kernel_signatures[i+" << CurStart << "];\n";
@@ -757,7 +880,6 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
757880
O << "\n";
758881
}
759882

760-
761883
bool SYCLIntegrationHeader::emit(const StringRef &IntHeaderName) {
762884
if (IntHeaderName.empty())
763885
return false;

clang/test/CodeGenSYCL/address-space-parameter-conversions.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clang -cc1 -triple spir64-unknown-linux-sycldevice -std=c++11 -fsycl-is-device -emit-llvm -x c++ %s -o - | FileCheck %s
1+
// RUN: %clang_cc1 -triple spir64-unknown-linux-sycldevice -std=c++11 -fsycl-is-device -disable-llvm-passes -emit-llvm -x c++ %s -o - | opt -asfix -S -o - | FileCheck %s
22
void bar(int & Data) {}
33
// CHECK: define spir_func void @[[RAW_REF:[a-zA-Z0-9_]+]](i32* dereferenceable(4) %
44
void bar2(int & Data) {}
@@ -144,13 +144,13 @@ void usages2() {
144144
// CHECK: call spir_func void @new.[[RAW_REF2]](i32 addrspace(4)* [[LOCAL_CAST]])
145145
}
146146

147-
// CHECK: define spir_func void @new.[[RAW_REF]](i32 addrspace(4)* dereferenceable(4)
147+
// CHECK-DAG: define spir_func void @new.[[RAW_REF]](i32 addrspace(4)* dereferenceable(4)
148148

149-
// CHECK: define spir_func void @new.[[RAW_REF2]](i32 addrspace(4)* dereferenceable(4)
149+
// CHECK-DAG: define spir_func void @new.[[RAW_REF2]](i32 addrspace(4)* dereferenceable(4)
150150

151-
// CHECK: define spir_func void @new.[[RAW_PTR2]](i32 addrspace(4)*
151+
// CHECK-DAG: define spir_func void @new.[[RAW_PTR]](i32 addrspace(4)*
152152

153-
// CHECK: define spir_func void @new.[[RAW_PTR]](i32 addrspace(4)*
153+
// CHECK-DAG: define spir_func void @new.[[RAW_PTR2]](i32 addrspace(4)*
154154

155155
template <typename name, typename Func>
156156
__attribute__((sycl_kernel)) void kernel_single_task(Func kernelFunc) {

clang/test/SemaSYCL/accessors-targets.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ int main() {
2424

2525
myQueue.wait();
2626
}
27-
// CHECK: kernel_function 'void (__local int *__local, __global int *__global)'
27+
// CHECK: kernel_function 'void (__local int *__local, range<1>, __global int *__global, range<1>)'

clang/test/SemaSYCL/built-in-type-kernel-arg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@ int main() {
1515
});
1616
return 0;
1717
}
18-
// CHECK: kernel_function 'void (__global int *__global, int)
18+
// CHECK: kernel_function 'void (__global int *__global, range<1>, int)

clang/test/SemaSYCL/fake-accessors.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,6 @@ int main() {
5050
});
5151
return 0;
5252
}
53-
// CHECK: fake_accessors 'void (__global int *__global, foo::cl::sycl::accessor, accessor)
54-
// CHECK: accessor_typedef 'void (__global int *__global, foo::cl::sycl::accessor, accessor)
55-
// CHECK: accessor_alias 'void (__global int *__global, foo::cl::sycl::accessor, accessor)
53+
// CHECK: fake_accessors 'void (__global int *__global, range<1>, foo::cl::sycl::accessor, accessor)
54+
// CHECK: accessor_typedef 'void (__global int *__global, range<1>, foo::cl::sycl::accessor, accessor)
55+
// CHECK: accessor_alias 'void (__global int *__global, range<1>, foo::cl::sycl::accessor, accessor)

0 commit comments

Comments
 (0)