Skip to content

Commit 7feb6d7

Browse files
committed
[SYCL] Fix alignment without breaking the ABI
1 parent 979dd04 commit 7feb6d7

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

sycl/include/CL/sycl/detail/accessor_impl.hpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#pragma once
1010

1111
#include <CL/sycl/access/access.hpp>
12-
#include <CL/sycl/detail/aligned_allocator.hpp>
1312
#include <CL/sycl/detail/export.hpp>
1413
#include <CL/sycl/detail/sycl_mem_obj_i.hpp>
1514
#include <CL/sycl/id.hpp>
@@ -178,22 +177,35 @@ class __SYCL_EXPORT LocalAccessorImplHost {
178177
sycl::range<3> MSize;
179178
int MDims;
180179
int MElemSize;
181-
std::vector<char, aligned_allocator<char>> MMem;
180+
std::vector<char> MMem;
182181
};
183182

184183
using LocalAccessorImplPtr = std::shared_ptr<LocalAccessorImplHost>;
185184

186185
class LocalAccessorBaseHost {
187186
public:
188187
LocalAccessorBaseHost(sycl::range<3> Size, int Dims, int ElemSize) {
188+
// Allocate ElemSize more data to have sufficient padding to enforce
189+
// alignment.
189190
impl = std::shared_ptr<LocalAccessorImplHost>(
190-
new LocalAccessorImplHost(Size, Dims, ElemSize));
191+
new LocalAccessorImplHost(Size + ElemSize, Dims, ElemSize));
191192
}
192193
sycl::range<3> &getSize() { return impl->MSize; }
193194
const sycl::range<3> &getSize() const { return impl->MSize; }
194-
void *getPtr() { return impl->MMem.data(); }
195+
void *getPtr() {
196+
// Const cast this in order to call the const getPtr.
197+
return const_cast<const LocalAccessorBaseHost *>(this)->getPtr();
198+
}
195199
void *getPtr() const {
196-
return const_cast<void *>(reinterpret_cast<void *>(impl->MMem.data()));
200+
char *ptr = impl->MMem.data();
201+
202+
// Align the pointer to MElemSize.
203+
size_t val = reinterpret_cast<size_t>(ptr);
204+
if (val % impl->MElemSize != 0) {
205+
ptr += impl->MElemSize - val % impl->MElemSize;
206+
}
207+
208+
return ptr;
197209
}
198210

199211
int getNumOfDims() { return impl->MDims; }

0 commit comments

Comments
 (0)