Skip to content

Commit 36f6ab6

Browse files
[SYCL] Fix crash when kernel argument is a multi-dimensional array. (#2341)
This patch fixes crash due to incorrect InitializedEntity for multi-dimensional arrays. When generating the InitializedEntity for an element, it is necessary to descend the array. For example, the initialized entity for s.array[x][y][z] is constructed using initialized entities for s.array[x][y], s.array[x] and s.array. Prior to this patch, the 'descending' was not done. Patch by: Rajiv Deodhar and Elizabeth Andrews Signed-off-by: Elizabeth Andrews <[email protected]>
1 parent 96da39e commit 36f6ab6

File tree

5 files changed

+239
-42
lines changed

5 files changed

+239
-42
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 90 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -876,11 +876,12 @@ class KernelObjVisitor {
876876

877877
assert(ElemCount > 0 && "SYCL prohibits 0 sized arrays");
878878
VisitFirstElement(nullptr, FD, ET, handlers...);
879-
(void)std::initializer_list<int>{(handlers.nextElement(ET), 0)...};
879+
(void)std::initializer_list<int>{(handlers.nextElement(ET, 1), 0)...};
880880

881881
for (int64_t Count = 1; Count < ElemCount; Count++) {
882882
VisitNthElement(nullptr, FD, ET, handlers...);
883-
(void)std::initializer_list<int>{(handlers.nextElement(ET), 0)...};
883+
(void)std::initializer_list<int>{
884+
(handlers.nextElement(ET, Count + 1), 0)...};
884885
}
885886

886887
(void)std::initializer_list<int>{
@@ -1085,7 +1086,7 @@ class SyclKernelFieldHandlerBase {
10851086
virtual bool enterField(const CXXRecordDecl *, FieldDecl *) { return true; }
10861087
virtual bool leaveField(const CXXRecordDecl *, FieldDecl *) { return true; }
10871088
virtual bool enterArray() { return true; }
1088-
virtual bool nextElement(QualType) { return true; }
1089+
virtual bool nextElement(QualType, uint64_t) { return true; }
10891090
virtual bool leaveArray(FieldDecl *, QualType, int64_t) { return true; }
10901091

10911092
virtual ~SyclKernelFieldHandlerBase() = default;
@@ -1665,7 +1666,6 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
16651666
InitializedEntity VarEntity;
16661667
const CXXRecordDecl *KernelObj;
16671668
llvm::SmallVector<Expr *, 16> MemberExprBases;
1668-
uint64_t ArrayIndex;
16691669
FunctionDecl *KernelCallerFunc;
16701670

16711671
// Using the statements/init expressions that we've created, this generates
@@ -1778,17 +1778,62 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
17781778
InitExprs.push_back(MemberInit.get());
17791779
}
17801780

1781+
int getDims() {
1782+
int Dims = 0;
1783+
for (int i = MemberExprBases.size() - 1; i >= 0; --i) {
1784+
if (!isa<ArraySubscriptExpr>(MemberExprBases[i]))
1785+
break;
1786+
++Dims;
1787+
}
1788+
return Dims;
1789+
}
1790+
1791+
int64_t getArrayIndex(int Idx) {
1792+
ArraySubscriptExpr *LastArrayRef =
1793+
cast<ArraySubscriptExpr>(MemberExprBases[Idx]);
1794+
Expr *LastIdx = LastArrayRef->getIdx();
1795+
llvm::APSInt Result;
1796+
SemaRef.VerifyIntegerConstantExpression(LastIdx, &Result);
1797+
return Result.getExtValue();
1798+
}
1799+
17811800
void createExprForScalarElement(FieldDecl *FD) {
1782-
InitializedEntity ArrayEntity =
1801+
llvm::SmallVector<InitializedEntity, 4> InitEntities;
1802+
1803+
// For multi-dimensional arrays, an initialized entity needs to be
1804+
// generated for each 'dimension'. For example, the initialized entity
1805+
// for s.array[x][y][z] is constructed using initialized entities for
1806+
// s.array[x][y], s.array[x] and s.array. InitEntities is used to maintain
1807+
// this.
1808+
InitializedEntity Entity =
17831809
InitializedEntity::InitializeMember(FD, &VarEntity);
1810+
InitEntities.push_back(Entity);
1811+
1812+
// Calculate dimension using ArraySubscriptExpressions in MemberExprBases.
1813+
// Each dimension has an ArraySubscriptExpression (maintains index)
1814+
// in MemberExprBases. For example, if we are currently handling element
1815+
// a[0][0][1], the top of stack entries are ArraySubscriptExpressions for
1816+
// indices 0,0 and 1, with 1 on top.
1817+
int Dims = getDims();
1818+
1819+
// MemberExprBasesIdx is used to get the index of each dimension, in correct
1820+
// order, from MemberExprBases. For example for a[0][0][1], getArrayIndex
1821+
// will return 0, 0 and then 1.
1822+
int MemberExprBasesIdx = MemberExprBases.size() - Dims;
1823+
for (int I = 0; I < Dims; ++I) {
1824+
InitializedEntity NewEntity = InitializedEntity::InitializeElement(
1825+
SemaRef.getASTContext(), getArrayIndex(MemberExprBasesIdx),
1826+
InitEntities.back());
1827+
InitEntities.push_back(NewEntity);
1828+
++MemberExprBasesIdx;
1829+
}
1830+
17841831
InitializationKind InitKind =
17851832
InitializationKind::CreateCopy(SourceLocation(), SourceLocation());
17861833
Expr *DRE = createInitExpr(FD);
1787-
InitializedEntity Entity = InitializedEntity::InitializeElement(
1788-
SemaRef.getASTContext(), ArrayIndex, ArrayEntity);
1789-
ArrayIndex++;
1790-
InitializationSequence InitSeq(SemaRef, Entity, InitKind, DRE);
1791-
ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, DRE);
1834+
InitializationSequence InitSeq(SemaRef, InitEntities.back(), InitKind, DRE);
1835+
ExprResult MemberInit =
1836+
InitSeq.Perform(SemaRef, InitEntities.back(), InitKind, DRE);
17921837
InitExprs.push_back(MemberInit.get());
17931838
}
17941839

@@ -1802,7 +1847,22 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
18021847
Expr *ILE = new (SemaRef.getASTContext())
18031848
InitListExpr(SemaRef.getASTContext(), SourceLocation(), ArrayInitExprs,
18041849
SourceLocation());
1805-
ILE->setType(FD->getType());
1850+
1851+
// We need to find the type of the element for which we are generating the
1852+
// InitListExpr. For example, for a multi-dimensional array say a[2][3][2],
1853+
// the types for InitListExpr of the array and its 'sub-arrays' are -
1854+
// int [2][3][2], int [3][2] and int [2]. This loop is used to obtain this
1855+
// information from MemberExprBases. MemberExprBases holds
1856+
// ArraySubscriptExprs and the top of stack shows how far we have descended
1857+
// down the array. getDims() calculates this depth.
1858+
QualType ILEType = FD->getType();
1859+
for (int I = getDims(); I > 1; I--) {
1860+
const ConstantArrayType *CAT =
1861+
SemaRef.getASTContext().getAsConstantArrayType(ILEType);
1862+
assert(CAT && "Should only be called on constant-size array.");
1863+
ILEType = CAT->getElementType();
1864+
}
1865+
ILE->setType(ILEType);
18061866
InitExprs.push_back(ILE);
18071867
}
18081868

@@ -2063,20 +2123,18 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
20632123
ExprResult ElementBase = SemaRef.CreateBuiltinArraySubscriptExpr(
20642124
ArrayBase, SourceLocation(), IndexExpr.get(), SourceLocation());
20652125
MemberExprBases.push_back(ElementBase.get());
2066-
ArrayIndex = 0;
20672126
return true;
20682127
}
20692128

2070-
bool nextElement(QualType ET) final {
2071-
ArraySubscriptExpr *LastArrayRef =
2072-
cast<ArraySubscriptExpr>(MemberExprBases.back());
2129+
bool nextElement(QualType ET, uint64_t) final {
2130+
// Top of MemberExprBases holds ArraySubscriptExpression of element
2131+
// we just handled, or the Array base for the dimension we are
2132+
// currently visiting.
2133+
int64_t nextIndex = getArrayIndex(MemberExprBases.size() - 1) + 1;
20732134
MemberExprBases.pop_back();
2074-
Expr *LastIdx = LastArrayRef->getIdx();
2075-
llvm::APSInt Result;
2076-
SemaRef.VerifyIntegerConstantExpression(LastIdx, &Result);
20772135
Expr *ArrayBase = MemberExprBases.back();
2078-
ExprResult IndexExpr = SemaRef.ActOnIntegerConstant(
2079-
SourceLocation(), Result.getExtValue() + 1);
2136+
ExprResult IndexExpr =
2137+
SemaRef.ActOnIntegerConstant(SourceLocation(), nextIndex);
20802138
ExprResult ElementBase = SemaRef.CreateBuiltinArraySubscriptExpr(
20812139
ArrayBase, SourceLocation(), IndexExpr.get(), SourceLocation());
20822140
MemberExprBases.push_back(ElementBase.get());
@@ -2101,6 +2159,7 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
21012159
class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
21022160
SYCLIntegrationHeader &Header;
21032161
int64_t CurOffset = 0;
2162+
llvm::SmallVector<size_t, 16> ArrayBaseOffsets;
21042163
int StructDepth = 0;
21052164

21062165
// A series of functions to calculate the change in offset based on the type.
@@ -2248,18 +2307,20 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
22482307
return true;
22492308
}
22502309

2251-
bool nextElement(QualType ET) final {
2252-
CurOffset += SemaRef.getASTContext().getTypeSizeInChars(ET).getQuantity();
2310+
bool enterArray() final {
2311+
ArrayBaseOffsets.push_back(CurOffset);
22532312
return true;
22542313
}
22552314

2256-
bool leaveArray(FieldDecl *, QualType ET, int64_t Count) final {
2257-
int64_t ArraySize =
2258-
SemaRef.getASTContext().getTypeSizeInChars(ET).getQuantity();
2259-
if (!ET->isArrayType()) {
2260-
ArraySize *= Count;
2261-
}
2262-
CurOffset -= ArraySize;
2315+
bool nextElement(QualType ET, uint64_t Index) final {
2316+
int64_t Size = SemaRef.getASTContext().getTypeSizeInChars(ET).getQuantity();
2317+
CurOffset = ArrayBaseOffsets.back() + Size * (Index);
2318+
return true;
2319+
}
2320+
2321+
bool leaveArray(FieldDecl *, QualType ET, int64_t) final {
2322+
CurOffset = ArrayBaseOffsets.back();
2323+
ArrayBaseOffsets.pop_back();
22632324
return true;
22642325
}
22652326
using SyclKernelFieldHandler::enterStruct;

clang/test/CodeGenSYCL/kernel-param-pod-array-ih.cpp

100755100644
Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313

1414
// CHECK: static constexpr
1515
// CHECK-NEXT: const char* const kernel_names[] = {
16-
// CHECK-NEXT: "_ZTSZ4mainE8kernel_B"
16+
// CHECK-NEXT: "_ZTSZ4mainE8kernel_B",
17+
// CHECK-NEXT: "_ZTSZ4mainE8kernel_C",
18+
// CHECK-NEXT: "_ZTSZ4mainE8kernel_D"
1719
// CHECK-NEXT: };
1820

1921
// CHECK: static constexpr
@@ -25,14 +27,40 @@
2527
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 12 },
2628
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 16 },
2729
// CHECK-EMPTY:
30+
// CHECK-NEXT: //--- _ZTSZ4mainE8kernel_C
31+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 },
32+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 4 },
33+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 8 },
34+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 12 },
35+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 16 },
36+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 20 },
37+
// CHECK-EMPTY:
38+
// CHECK-NEXT: //--- _ZTSZ4mainE8kernel_D
39+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 },
40+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 4 },
41+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 8 },
42+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 12 },
43+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 16 },
44+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 20 },
45+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 24 },
46+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 28 },
47+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 32 },
48+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 36 },
49+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 40 },
50+
// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 44 },
51+
// CHECK-EMPTY:
2852
// CHECK-NEXT: };
2953

3054
// CHECK: static constexpr
3155
// CHECK-NEXT: const unsigned kernel_signature_start[] = {
32-
// CHECK-NEXT: 0 // _ZTSZ4mainE8kernel_B
56+
// CHECK-NEXT: 0, // _ZTSZ4mainE8kernel_B
57+
// CHECK-NEXT: 6, // _ZTSZ4mainE8kernel_C
58+
// CHECK-NEXT: 13 // _ZTSZ4mainE8kernel_D
3359
// CHECK-NEXT: };
3460

3561
// CHECK: template <> struct KernelInfo<class kernel_B> {
62+
// CHECK: template <> struct KernelInfo<class kernel_C> {
63+
// CHECK: template <> struct KernelInfo<class kernel_D> {
3664

3765
#include <sycl.hpp>
3866

@@ -46,9 +74,21 @@ __attribute__((sycl_kernel)) void a_kernel(const Func &kernelFunc) {
4674
int main() {
4775

4876
int a[5];
77+
int b[2][3];
78+
int c[2][3][2];
4979

5080
a_kernel<class kernel_B>(
5181
[=]() {
5282
int local = a[3];
5383
});
84+
85+
a_kernel<class kernel_C>(
86+
[=]() {
87+
int local = b[0][1];
88+
});
89+
90+
a_kernel<class kernel_D>(
91+
[=]() {
92+
int local = c[0][1][1];
93+
});
5494
}

clang/test/CodeGenSYCL/kernel-param-pod-array.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ struct foo {
2525
int main() {
2626

2727
int a[2];
28+
int array_2D[2][1];
2829
foo struct_array[2];
2930

3031
a_kernel<class kernel_B>(
@@ -36,6 +37,11 @@ int main() {
3637
[=]() {
3738
foo local = struct_array[1];
3839
});
40+
41+
a_kernel<class kernel_D>(
42+
[=]() {
43+
int local = array_2D[0][0];
44+
});
3945
}
4046

4147
// Check kernel_B parameters
@@ -151,3 +157,25 @@ int main() {
151157
// CHECK: [[GEP_FOO2_C:%[a-zA-Z0-9_]+]] = getelementptr inbounds %struct.{{.*}}foo.foo, %struct.{{.*}}foo.foo* [[FOO_ARRAY_1]], i32 0, i32 2
152158
// CHECK: [[LOAD_FOO2_C:%[a-zA-Z0-9_]+]] = load i32, i32* [[FOO2_C_LOCAL]], align 4
153159
// CHECK: store i32 [[LOAD_FOO2_C]], i32* [[GEP_FOO2_C]], align 4
160+
161+
// Check kernel_D parameters
162+
// CHECK: define spir_kernel void @{{.*}}kernel_D
163+
// CHECK-SAME: i32 [[ARR_2D_1:%[a-zA-Z0-9_]+]], i32 [[ARR_2D_2:%[a-zA-Z0-9_]+]]
164+
165+
// Check local lambda object alloca
166+
// CHECK: [[LAMBDA_OBJ:%[0-9]+]] = alloca %"class.{{.*}}.anon.1", align 4
167+
168+
// Check local stores
169+
// CHECK: store i32 [[ARR_2D_1]], i32* [[ARR_2D_1_LOCAL:%[a-zA-Z_]+.addr[0-9]*]], align 4
170+
// CHECK: store i32 [[ARR_2D_2]], i32* [[ARR_2D_2_LOCAL:%[a-zA-Z_]+.addr[0-9]*]], align 4
171+
172+
// Check initialization of local array
173+
// CHECK: [[GEP_ARR_2D:%[0-9]*]] = getelementptr inbounds %"class._ZTSZ4mainE3$_0.anon.1", %"class._ZTSZ4mainE3$_0.anon.1"* [[LAMBDA_OBJ]], i32 0, i32 0
174+
// CHECK: [[GEP_ARR_BEGIN1:%[a-zA-Z0-9_.]+]] = getelementptr inbounds [2 x [1 x i32]], [2 x [1 x i32]]* [[GEP_ARR_2D]], i64 0, i64 0
175+
// CHECK: [[GEP_ARR_ELEM0:%[a-zA-Z0-9_.]+]] = getelementptr inbounds [1 x i32], [1 x i32]* [[GEP_ARR_BEGIN1]], i64 0, i64 0
176+
// CHECK: [[ARR_2D_ELEM0:%[0-9]*]] = load i32, i32* [[ARR_2D_1_LOCAL]], align 4
177+
// CHECK: store i32 [[ARR_2D_ELEM0]], i32* [[GEP_ARR_ELEM0]], align 4
178+
// CHECK: [[GEP_ARR_BEGIN2:%[a-zA-Z_.]+]] = getelementptr inbounds [1 x i32], [1 x i32]* [[GEP_ARR_BEGIN1]], i64 1
179+
// CHECK: [[GEP_ARR_ELEM1:%[a-zA-Z0-9_.]+]] = getelementptr inbounds [1 x i32], [1 x i32]* [[GEP_ARR_BEGIN2]], i64 0, i64 0
180+
// CHECK: [[ARR_2D_ELEM1:%[0-9]*]] = load i32, i32* [[ARR_2D_2_LOCAL]], align 4
181+
// CHECK: store i32 [[ARR_2D_ELEM1]], i32* [[GEP_ARR_ELEM1]], align 4

0 commit comments

Comments
 (0)