Skip to content

Commit 99a9236

Browse files
Pass queue to ptr_adapter (#1129)
* Pass queue to ptr_adapter
1 parent 7d8f2a0 commit 99a9236

18 files changed

+226
-195
lines changed

dpnp/backend/include/dpnp_iface.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,18 @@ INP_DLLEXPORT size_t dpnp_queue_is_cpu_c();
106106
* Memory allocation on the SYCL backend.
107107
*
108108
* @param [in] size_in_bytes Number of bytes for requested memory allocation.
109+
* @param [in] q_ref Reference to SYCL queue.
109110
*
110111
* @return A pointer to newly created memory on @ref dpnp_queue_initialize_c "initialized SYCL device".
111112
*/
113+
INP_DLLEXPORT char* dpnp_memory_alloc_c(DPCTLSyclQueueRef q_ref, size_t size_in_bytes);
112114
INP_DLLEXPORT char* dpnp_memory_alloc_c(size_t size_in_bytes);
113115

116+
INP_DLLEXPORT void dpnp_memory_free_c(DPCTLSyclQueueRef q_ref, void* ptr);
114117
INP_DLLEXPORT void dpnp_memory_free_c(void* ptr);
115-
void dpnp_memory_memcpy_c(void* dst, const void* src, size_t size_in_bytes);
116118

119+
INP_DLLEXPORT void dpnp_memory_memcpy_c(DPCTLSyclQueueRef q_ref, void* dst, const void* src, size_t size_in_bytes);
120+
INP_DLLEXPORT void dpnp_memory_memcpy_c(void* dst, const void* src, size_t size_in_bytes);
117121
/**
118122
* @ingroup BACKEND_API
119123
* @brief Test whether all array elements along a given axis evaluate to True.

dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ DPCTLSyclEventRef dpnp_diag_c(DPCTLSyclQueueRef q_ref,
123123

124124
const size_t input1_size = std::accumulate(shape, shape + ndim, 1, std::multiplies<shape_elem_type>());
125125
const size_t result_size = std::accumulate(res_shape, res_shape + res_ndim, 1, std::multiplies<shape_elem_type>());
126-
DPNPC_ptr_adapter<_DataType> input1_ptr(v_in, input1_size, true);
127-
DPNPC_ptr_adapter<_DataType> result_ptr(result1, result_size, true, true);
126+
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref,v_in, input1_size, true);
127+
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref,result1, result_size, true, true);
128128
_DataType* v = input1_ptr.get_ptr();
129129
_DataType* result = result_ptr.get_ptr();
130130

@@ -219,7 +219,7 @@ DPCTLSyclEventRef dpnp_eye_c(DPCTLSyclQueueRef q_ref,
219219

220220
size_t result_size = res_shape[0] * res_shape[1];
221221

222-
DPNPC_ptr_adapter<_DataType> result_ptr(result1, result_size, true, true);
222+
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref,result1, result_size, true, true);
223223
_DataType* result = result_ptr.get_ptr();
224224

225225
int diag_val_;
@@ -506,8 +506,8 @@ DPCTLSyclEventRef dpnp_ptp_c(DPCTLSyclQueueRef q_ref,
506506

507507
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
508508

509-
DPNPC_ptr_adapter<_DataType> input1_ptr(input1_in, input_size, true);
510-
DPNPC_ptr_adapter<_DataType> result_ptr(result1_out, result_size, false, true);
509+
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref,input1_in, input_size, true);
510+
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref,result1_out, result_size, false, true);
511511
_DataType* arr = input1_ptr.get_ptr();
512512
_DataType* result = result_ptr.get_ptr();
513513

@@ -630,8 +630,8 @@ DPCTLSyclEventRef dpnp_vander_c(DPCTLSyclQueueRef q_ref,
630630

631631
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
632632

633-
DPNPC_ptr_adapter<_DataType_input> input1_ptr(array1_in, size_in, true);
634-
DPNPC_ptr_adapter<_DataType_output> result_ptr(result1, size_in * N, true, true);
633+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(q_ref,array1_in, size_in, true);
634+
DPNPC_ptr_adapter<_DataType_output> result_ptr(q_ref,result1, size_in * N, true, true);
635635
const _DataType_input* array_in = input1_ptr.get_ptr();
636636
_DataType_output* result = result_ptr.get_ptr();
637637

@@ -739,7 +739,8 @@ DPCTLSyclEventRef dpnp_trace_c(DPCTLSyclQueueRef q_ref,
739739

740740
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
741741

742-
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, size * last_dim);
742+
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref,array1_in, size * last_dim);
743+
743744
const _DataType* input = input1_ptr.get_ptr();
744745
_ResultType* result = reinterpret_cast<_ResultType*>(result_in);
745746

@@ -926,8 +927,8 @@ DPCTLSyclEventRef dpnp_tril_c(DPCTLSyclQueueRef q_ref,
926927

927928
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
928929

929-
DPNPC_ptr_adapter<_DataType> input1_ptr(array_in, input_size, true);
930-
DPNPC_ptr_adapter<_DataType> result_ptr(result1, res_size, true, true);
930+
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref,array_in, input_size, true);
931+
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref,result1, res_size, true, true);
931932
_DataType* array_m = input1_ptr.get_ptr();
932933
_DataType* result = result_ptr.get_ptr();
933934

@@ -1086,8 +1087,8 @@ DPCTLSyclEventRef dpnp_triu_c(DPCTLSyclQueueRef q_ref,
10861087

10871088
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
10881089

1089-
DPNPC_ptr_adapter<_DataType> input1_ptr(array_in, input_size, true);
1090-
DPNPC_ptr_adapter<_DataType> result_ptr(result1, res_size, true, true);
1090+
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref,array_in, input_size, true);
1091+
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref,result1, res_size, true, true);
10911092
_DataType* array_m = input1_ptr.get_ptr();
10921093
_DataType* result = result_ptr.get_ptr();
10931094

dpnp/backend/kernels/dpnp_krnl_bitwise.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ DPCTLSyclEventRef dpnp_invert_c(DPCTLSyclQueueRef q_ref,
4949
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
5050
sycl::event event;
5151

52-
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, size);
52+
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size);
5353
_DataType* array1 = input1_ptr.get_ptr();
5454
_DataType* result = reinterpret_cast<_DataType*>(result1);
5555

@@ -145,16 +145,16 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
145145
\
146146
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref)); \
147147
\
148-
DPNPC_ptr_adapter<_DataType> input1_ptr(input1_in, input1_size); \
149-
DPNPC_ptr_adapter<shape_elem_type> input1_shape_ptr(input1_shape, input1_ndim, true); \
150-
DPNPC_ptr_adapter<shape_elem_type> input1_strides_ptr(input1_strides, input1_ndim, true); \
148+
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, input1_in, input1_size); \
149+
DPNPC_ptr_adapter<shape_elem_type> input1_shape_ptr(q_ref, input1_shape, input1_ndim, true); \
150+
DPNPC_ptr_adapter<shape_elem_type> input1_strides_ptr(q_ref, input1_strides, input1_ndim, true); \
151151
\
152-
DPNPC_ptr_adapter<_DataType> input2_ptr(input2_in, input2_size); \
153-
DPNPC_ptr_adapter<shape_elem_type> input2_shape_ptr(input2_shape, input2_ndim, true); \
154-
DPNPC_ptr_adapter<shape_elem_type> input2_strides_ptr(input2_strides, input2_ndim, true); \
152+
DPNPC_ptr_adapter<_DataType> input2_ptr(q_ref, input2_in, input2_size); \
153+
DPNPC_ptr_adapter<shape_elem_type> input2_shape_ptr(q_ref, input2_shape, input2_ndim, true); \
154+
DPNPC_ptr_adapter<shape_elem_type> input2_strides_ptr(q_ref, input2_strides, input2_ndim, true); \
155155
\
156-
DPNPC_ptr_adapter<_DataType> result_ptr(result_out, result_size, false, true); \
157-
DPNPC_ptr_adapter<shape_elem_type> result_strides_ptr(result_strides, result_ndim); \
156+
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref, result_out, result_size, false, true); \
157+
DPNPC_ptr_adapter<shape_elem_type> result_strides_ptr(q_ref, result_strides, result_ndim); \
158158
\
159159
_DataType* input1_data = input1_ptr.get_ptr(); \
160160
shape_elem_type* input1_shape_data = input1_shape_ptr.get_ptr(); \

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ DPCTLSyclEventRef dpnp_astype_c(DPCTLSyclQueueRef q_ref,
5555
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
5656
sycl::event event;
5757

58-
DPNPC_ptr_adapter<_DataType> input1_ptr(array1_in, size);
58+
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size);
5959
const _DataType* array_in = input1_ptr.get_ptr();
6060
_ResultType* result = reinterpret_cast<_ResultType*>(result1);
6161

@@ -217,9 +217,8 @@ DPCTLSyclEventRef dpnp_dot_c(DPCTLSyclQueueRef q_ref,
217217
DPCTLSyclEventRef event_ref = nullptr;
218218
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
219219

220-
DPNPC_ptr_adapter<_DataType_input1> input1_ptr(input1_in, input1_size);
221-
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(input2_in, input2_size);
222-
220+
DPNPC_ptr_adapter<_DataType_input1> input1_ptr(q_ref, input1_in, input1_size);
221+
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(q_ref, input2_in, input2_size);
223222
_DataType_input1* input1 = input1_ptr.get_ptr();
224223
_DataType_input2* input2 = input2_ptr.get_ptr();
225224
_DataType_output* result = reinterpret_cast<_DataType_output*>(result_out);
@@ -516,9 +515,9 @@ DPCTLSyclEventRef dpnp_eig_c(DPCTLSyclQueueRef q_ref,
516515
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
517516
sycl::event event;
518517

519-
DPNPC_ptr_adapter<_DataType> input1_ptr(array_in, size * size, true);
520-
DPNPC_ptr_adapter<_ResultType> result1_ptr(result1, size, true, true);
521-
DPNPC_ptr_adapter<_ResultType> result2_ptr(result2, size * size, true, true);
518+
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array_in, size * size, true);
519+
DPNPC_ptr_adapter<_ResultType> result1_ptr(q_ref, result1, size, true, true);
520+
DPNPC_ptr_adapter<_ResultType> result2_ptr(q_ref, result2, size * size, true, true);
522521
const _DataType* array = input1_ptr.get_ptr();
523522
_ResultType* result_val = result1_ptr.get_ptr();
524523
_ResultType* result_vec = result2_ptr.get_ptr();
@@ -627,8 +626,8 @@ DPCTLSyclEventRef dpnp_eigvals_c(DPCTLSyclQueueRef q_ref,
627626
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
628627
sycl::event event;
629628

630-
DPNPC_ptr_adapter<_DataType> input1_ptr(array_in, size * size, true);
631-
DPNPC_ptr_adapter<_ResultType> result1_ptr(result1, size, true, true);
629+
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array_in, size * size, true);
630+
DPNPC_ptr_adapter<_ResultType> result1_ptr(q_ref, result1, size, true, true);
632631
const _DataType* array = input1_ptr.get_ptr();
633632
_ResultType* result_val = result1_ptr.get_ptr();
634633

@@ -716,8 +715,8 @@ DPCTLSyclEventRef dpnp_initval_c(DPCTLSyclQueueRef q_ref,
716715

717716
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
718717

719-
DPNPC_ptr_adapter<_DataType> result1_ptr(result1, size);
720-
DPNPC_ptr_adapter<_DataType> value_ptr(value, 1);
718+
DPNPC_ptr_adapter<_DataType> result1_ptr(q_ref, result1, size);
719+
DPNPC_ptr_adapter<_DataType> value_ptr(q_ref, value, 1);
721720
_DataType* result = result1_ptr.get_ptr();
722721
_DataType* val = value_ptr.get_ptr();
723722

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@
7070
\
7171
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref)); \
7272
\
73-
DPNPC_ptr_adapter<_DataType_input> input1_ptr(input1_in, input1_size); \
74-
DPNPC_ptr_adapter<shape_elem_type> input1_shape_ptr(input1_shape, input1_ndim, true); \
75-
DPNPC_ptr_adapter<shape_elem_type> input1_strides_ptr(input1_strides, input1_ndim, true); \
73+
DPNPC_ptr_adapter<_DataType_input> input1_ptr(q_ref, input1_in, input1_size); \
74+
DPNPC_ptr_adapter<shape_elem_type> input1_shape_ptr(q_ref, input1_shape, input1_ndim, true); \
75+
DPNPC_ptr_adapter<shape_elem_type> input1_strides_ptr(q_ref, input1_strides, input1_ndim, true); \
7676
\
77-
DPNPC_ptr_adapter<_DataType_output> result_ptr(result_out, result_size, false, true); \
78-
DPNPC_ptr_adapter<shape_elem_type> result_strides_ptr(result_strides, result_ndim); \
77+
DPNPC_ptr_adapter<_DataType_output> result_ptr(q_ref, result_out, result_size, false, true); \
78+
DPNPC_ptr_adapter<shape_elem_type> result_strides_ptr(q_ref, result_strides, result_ndim); \
7979
\
8080
_DataType_input* input1_data = input1_ptr.get_ptr(); \
8181
shape_elem_type* input1_shape_data = input1_shape_ptr.get_ptr(); \
@@ -575,12 +575,11 @@ static void func_map_init_elemwise_1arg_2type(func_map_t& fmap)
575575
\
576576
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref)); \
577577
\
578-
DPNPC_ptr_adapter<_DataType> input1_ptr(input1_in, input1_size); \
579-
DPNPC_ptr_adapter<shape_elem_type> input1_shape_ptr(input1_shape, input1_ndim, true); \
580-
DPNPC_ptr_adapter<shape_elem_type> input1_strides_ptr(input1_strides, input1_ndim, true); \
581-
\
582-
DPNPC_ptr_adapter<_DataType> result_ptr(result_out, result_size, false, true); \
583-
DPNPC_ptr_adapter<shape_elem_type> result_strides_ptr(result_strides, result_ndim); \
578+
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, input1_in, input1_size); \
579+
DPNPC_ptr_adapter<shape_elem_type> input1_shape_ptr(q_ref, input1_shape, input1_ndim, true); \
580+
DPNPC_ptr_adapter<shape_elem_type> input1_strides_ptr(q_ref, input1_strides, input1_ndim, true); \
581+
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref, result_out, result_size, false, true); \
582+
DPNPC_ptr_adapter<shape_elem_type> result_strides_ptr(q_ref, result_strides, result_ndim); \
584583
\
585584
_DataType* input1_data = input1_ptr.get_ptr(); \
586585
shape_elem_type* input1_shape_data = input1_shape_ptr.get_ptr(); \
@@ -860,17 +859,16 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
860859
\
861860
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref)); \
862861
\
863-
DPNPC_ptr_adapter<_DataType_input1> input1_ptr(input1_in, input1_size); \
864-
DPNPC_ptr_adapter<shape_elem_type> input1_shape_ptr(input1_shape, input1_ndim, true); \
865-
DPNPC_ptr_adapter<shape_elem_type> input1_strides_ptr(input1_strides, input1_ndim, true); \
866-
\
867-
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(input2_in, input2_size); \
868-
DPNPC_ptr_adapter<shape_elem_type> input2_shape_ptr(input2_shape, input2_ndim, true); \
869-
DPNPC_ptr_adapter<shape_elem_type> input2_strides_ptr(input2_strides, input2_ndim, true); \
862+
DPNPC_ptr_adapter<_DataType_input1> input1_ptr(q_ref, input1_in, input1_size); \
863+
DPNPC_ptr_adapter<shape_elem_type> input1_shape_ptr(q_ref, input1_shape, input1_ndim, true); \
864+
DPNPC_ptr_adapter<shape_elem_type> input1_strides_ptr(q_ref, input1_strides, input1_ndim, true); \
865+
DPNPC_ptr_adapter<_DataType_input2> input2_ptr(q_ref, input2_in, input2_size); \
866+
DPNPC_ptr_adapter<shape_elem_type> input2_shape_ptr(q_ref, input2_shape, input2_ndim, true); \
867+
DPNPC_ptr_adapter<shape_elem_type> input2_strides_ptr(q_ref, input2_strides, input2_ndim, true); \
870868
\
871-
DPNPC_ptr_adapter<_DataType_output> result_ptr(result_out, result_size, false, true); \
872-
DPNPC_ptr_adapter<shape_elem_type> result_shape_ptr(result_shape, result_ndim); \
873-
DPNPC_ptr_adapter<shape_elem_type> result_strides_ptr(result_strides, result_ndim); \
869+
DPNPC_ptr_adapter<_DataType_output> result_ptr(q_ref, result_out, result_size, false, true); \
870+
DPNPC_ptr_adapter<shape_elem_type> result_shape_ptr(q_ref, result_shape, result_ndim); \
871+
DPNPC_ptr_adapter<shape_elem_type> result_strides_ptr(q_ref, result_strides, result_ndim); \
874872
\
875873
_DataType_input1* input1_data = input1_ptr.get_ptr(); \
876874
shape_elem_type* input1_shape_data = input1_shape_ptr.get_ptr(); \

0 commit comments

Comments
 (0)