Skip to content

Commit 1f3c101

Browse files
[SYCL] Enable using of the copy method with shared_ptr with const T.
Signed-off-by: Alexey Voronov <[email protected]>
1 parent 4646fc1 commit 1f3c101

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

sycl/include/CL/sycl/detail/cg.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ class CG {
325325

326326
CG(CGTYPE Type, std::vector<std::vector<char>> ArgsStorage,
327327
std::vector<detail::AccessorImplPtr> AccStorage,
328-
std::vector<std::shared_ptr<void>> SharedPtrStorage,
328+
std::vector<std::shared_ptr<const void>> SharedPtrStorage,
329329
std::vector<Requirement *> Requirements)
330330
: MType(Type), MArgsStorage(std::move(ArgsStorage)),
331331
MAccStorage(std::move(AccStorage)),
@@ -347,7 +347,7 @@ class CG {
347347
// Storage for accessors.
348348
std::vector<detail::AccessorImplPtr> MAccStorage;
349349
// Storage for shared_ptrs.
350-
std::vector<std::shared_ptr<void>> MSharedPtrStorage;
350+
std::vector<std::shared_ptr<const void>> MSharedPtrStorage;
351351
// List of requirements that specify which memory is needed for the command
352352
// group to be executed.
353353
std::vector<Requirement *> MRequirements;
@@ -368,7 +368,7 @@ class CGExecKernel : public CG {
368368
std::shared_ptr<detail::kernel_impl> SyclKernel,
369369
std::vector<std::vector<char>> ArgsStorage,
370370
std::vector<detail::AccessorImplPtr> AccStorage,
371-
std::vector<std::shared_ptr<void>> SharedPtrStorage,
371+
std::vector<std::shared_ptr<const void>> SharedPtrStorage,
372372
std::vector<Requirement *> Requirements,
373373
std::vector<ArgDesc> Args, std::string KernelName,
374374
detail::OSModuleHandle OSModuleHandle,
@@ -396,7 +396,7 @@ class CGCopy : public CG {
396396
CGCopy(CGTYPE CopyType, void *Src, void *Dst,
397397
std::vector<std::vector<char>> ArgsStorage,
398398
std::vector<detail::AccessorImplPtr> AccStorage,
399-
std::vector<std::shared_ptr<void>> SharedPtrStorage,
399+
std::vector<std::shared_ptr<const void>> SharedPtrStorage,
400400
std::vector<Requirement *> Requirements)
401401
: CG(CopyType, std::move(ArgsStorage), std::move(AccStorage),
402402
std::move(SharedPtrStorage), std::move(Requirements)),
@@ -414,7 +414,7 @@ class CGFill : public CG {
414414
CGFill(std::vector<char> Pattern, void *Ptr,
415415
std::vector<std::vector<char>> ArgsStorage,
416416
std::vector<detail::AccessorImplPtr> AccStorage,
417-
std::vector<std::shared_ptr<void>> SharedPtrStorage,
417+
std::vector<std::shared_ptr<const void>> SharedPtrStorage,
418418
std::vector<Requirement *> Requirements)
419419
: CG(FILL, std::move(ArgsStorage), std::move(AccStorage),
420420
std::move(SharedPtrStorage), std::move(Requirements)),
@@ -429,7 +429,7 @@ class CGUpdateHost : public CG {
429429
public:
430430
CGUpdateHost(void *Ptr, std::vector<std::vector<char>> ArgsStorage,
431431
std::vector<detail::AccessorImplPtr> AccStorage,
432-
std::vector<std::shared_ptr<void>> SharedPtrStorage,
432+
std::vector<std::shared_ptr<const void>> SharedPtrStorage,
433433
std::vector<Requirement *> Requirements)
434434
: CG(UPDATE_HOST, std::move(ArgsStorage), std::move(AccStorage),
435435
std::move(SharedPtrStorage), std::move(Requirements)),
@@ -438,6 +438,6 @@ class CGUpdateHost : public CG {
438438
Requirement *getReqToUpdate() { return MPtr; }
439439
};
440440

441-
} // namespace cl
442-
} // namespace sycl
443441
} // namespace detail
442+
} // namespace sycl
443+
} // namespace cl

sycl/include/CL/sycl/handler.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ class handler {
141141
std::vector<std::vector<char>> MArgsStorage;
142142
std::vector<detail::AccessorImplPtr> MAccStorage;
143143
std::vector<std::shared_ptr<detail::stream_impl>> MStreamStorage;
144-
std::vector<std::shared_ptr<void>> MSharedPtrStorage;
144+
std::vector<std::shared_ptr<const void>> MSharedPtrStorage;
145145
// The list of arguments for the kernel.
146146
std::vector<detail::ArgDesc> MArgs;
147147
// The list of associated accessors with this handler.
@@ -924,7 +924,7 @@ class handler {
924924
// Make sure data shared_ptr points to is not released until we finish
925925
// work with it.
926926
MSharedPtrStorage.push_back(Src);
927-
T_Dst *RawSrcPtr = Src.get();
927+
T_Src *RawSrcPtr = Src.get();
928928
copy(RawSrcPtr, Dst);
929929
}
930930

sycl/test/basic_tests/handler/handler_mem_op.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ template <typename T> void test_fill(T Val);
3232
template <typename T> void test_copy_ptr_acc();
3333
template <typename T> void test_copy_acc_ptr();
3434
template <typename T> void test_copy_shared_ptr_acc();
35+
template <typename T> void test_copy_shared_ptr_const_acc();
3536
template <typename T> void test_copy_acc_shared_ptr();
3637
template <typename T> void test_copy_acc_acc();
3738
template <typename T> void test_update_host();
@@ -73,6 +74,14 @@ int main() {
7374
test_copy_shared_ptr_acc<point<int>>();
7475
test_copy_shared_ptr_acc<point<float>>();
7576
}
77+
// handler.copy(const shared_ptr, acc)
78+
{
79+
test_copy_shared_ptr_const_acc<int>();
80+
test_copy_shared_ptr_const_acc<int>();
81+
test_copy_shared_ptr_const_acc<point<int>>();
82+
test_copy_shared_ptr_const_acc<point<int>>();
83+
test_copy_shared_ptr_const_acc<point<float>>();
84+
}
7685
// handler.copy(acc, shared_ptr)
7786
{
7887
test_copy_acc_shared_ptr<int>();
@@ -202,6 +211,24 @@ template <typename T> void test_copy_shared_ptr_acc() {
202211
}
203212
}
204213

214+
template <typename T> void test_copy_shared_ptr_const_acc() {
215+
constexpr size_t Size = 10;
216+
T Data[Size] = {0};
217+
std::shared_ptr<const T> Values(new T[Size]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
218+
{
219+
buffer<T, 1> Buffer(Data, range<1>(Size));
220+
queue Queue;
221+
Queue.submit([&](handler &Cgh) {
222+
accessor<T, 1, access::mode::write, access::target::global_buffer>
223+
Accessor(Buffer, Cgh, range<1>(Size));
224+
Cgh.copy(Values, Accessor);
225+
});
226+
}
227+
for (size_t I = 0; I < Size; ++I) {
228+
assert(Data[I] == Values.get()[I]);
229+
}
230+
}
231+
205232
template <typename T> void test_copy_acc_shared_ptr() {
206233
const size_t Size = 10;
207234
T Data[Size] = {0};

0 commit comments

Comments
 (0)