Skip to content

Commit f525309

Browse files
authored
Pass queue to iterator (#1130)
* Pass queue to iterator
1 parent 99a9236 commit f525309

File tree

6 files changed

+121
-58
lines changed

6 files changed

+121
-58
lines changed

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -906,17 +906,19 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
906906
{ \
907907
DPNPC_id<_DataType_input1>* input1_it; \
908908
const size_t input1_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input1>); \
909-
input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(sycl::malloc_shared(input1_it_size_in_bytes, q));\
909+
input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(dpnp_memory_alloc_c(q_ref, \
910+
input1_it_size_in_bytes)); \
910911
new (input1_it) \
911-
DPNPC_id<_DataType_input1>(input1_data, input1_shape_data, input1_strides_data, input1_ndim); \
912+
DPNPC_id<_DataType_input1>(q_ref, input1_data, input1_shape_data, input1_strides_data, input1_ndim); \
912913
\
913914
input1_it->broadcast_to_shape(result_shape_data, result_ndim); \
914915
\
915916
DPNPC_id<_DataType_input2>* input2_it; \
916917
const size_t input2_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input2>); \
917-
input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(sycl::malloc_shared(input2_it_size_in_bytes, q));\
918+
input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(dpnp_memory_alloc_c(q_ref, \
919+
input2_it_size_in_bytes)); \
918920
new (input2_it) \
919-
DPNPC_id<_DataType_input2>(input2_data, input2_shape_data, input2_strides_data, input2_ndim); \
921+
DPNPC_id<_DataType_input2>(q_ref, input2_data, input2_shape_data, input2_strides_data, input2_ndim); \
920922
\
921923
input2_it->broadcast_to_shape(result_shape_data, result_ndim); \
922924
\

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -587,15 +587,15 @@ DPCTLSyclEventRef dpnp_floor_divide_c(DPCTLSyclQueueRef q_ref,
587587

588588
DPNPC_id<_DataType_input1>* input1_it;
589589
const size_t input1_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input1>);
590-
input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(sycl::malloc_shared(input1_it_size_in_bytes, q));
591-
new (input1_it) DPNPC_id<_DataType_input1>(input1_data, input1_shape, input1_shape_ndim);
590+
input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(dpnp_memory_alloc_c(q_ref, input1_it_size_in_bytes));
591+
new (input1_it) DPNPC_id<_DataType_input1>(q_ref, input1_data, input1_shape, input1_shape_ndim);
592592

593593
input1_it->broadcast_to_shape(result_shape);
594594

595595
DPNPC_id<_DataType_input2>* input2_it;
596596
const size_t input2_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input2>);
597-
input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(sycl::malloc_shared(input2_it_size_in_bytes, q));
598-
new (input2_it) DPNPC_id<_DataType_input2>(input2_data, input2_shape, input2_shape_ndim);
597+
input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(dpnp_memory_alloc_c(q_ref, input2_it_size_in_bytes));
598+
new (input2_it) DPNPC_id<_DataType_input2>(q_ref, input2_data, input2_shape, input2_shape_ndim);
599599

600600
input2_it->broadcast_to_shape(result_shape);
601601

@@ -823,15 +823,15 @@ DPCTLSyclEventRef dpnp_remainder_c(DPCTLSyclQueueRef q_ref,
823823

824824
DPNPC_id<_DataType_input1>* input1_it;
825825
const size_t input1_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input1>);
826-
input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(sycl::malloc_shared(input1_it_size_in_bytes, q));
827-
new (input1_it) DPNPC_id<_DataType_input1>(input1_data, input1_shape, input1_shape_ndim);
826+
input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(dpnp_memory_alloc_c(q_ref, input1_it_size_in_bytes));
827+
new (input1_it) DPNPC_id<_DataType_input1>(q_ref, input1_data, input1_shape, input1_shape_ndim);
828828

829829
input1_it->broadcast_to_shape(result_shape);
830830

831831
DPNPC_id<_DataType_input2>* input2_it;
832832
const size_t input2_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input2>);
833-
input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(sycl::malloc_shared(input2_it_size_in_bytes, q));
834-
new (input2_it) DPNPC_id<_DataType_input2>(input2_data, input2_shape, input2_shape_ndim);
833+
input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(dpnp_memory_alloc_c(q_ref, input2_it_size_in_bytes));
834+
new (input2_it) DPNPC_id<_DataType_input2>(q_ref, input2_data, input2_shape, input2_shape_ndim);
835835

836836
input2_it->broadcast_to_shape(result_shape);
837837

dpnp/backend/kernels/dpnp_krnl_reduction.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ DPCTLSyclEventRef dpnp_sum_c(DPCTLSyclQueueRef q_ref,
119119
}
120120
}
121121

122-
DPNPC_id<_DataType_input> input_it(input, input_shape, input_shape_ndim);
122+
DPNPC_id<_DataType_input> input_it(q_ref, input, input_shape, input_shape_ndim);
123123
input_it.set_axes(axes, axes_ndim);
124124

125125
const size_t output_size = input_it.get_output_size();
@@ -235,7 +235,7 @@ DPCTLSyclEventRef dpnp_prod_c(DPCTLSyclQueueRef q_ref,
235235
return event_ref;
236236
}
237237

238-
DPNPC_id<_DataType_input> input_it(input, input_shape, input_shape_ndim);
238+
DPNPC_id<_DataType_input> input_it(q_ref, input, input_shape, input_shape_ndim);
239239
input_it.set_axes(axes, axes_ndim);
240240

241241
const size_t output_size = input_it.get_output_size();

dpnp/backend/src/dpnp_iterator.hpp

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,20 @@ class DPNPC_id final
201201
using reference = value_type&;
202202
using size_type = shape_elem_type;
203203

204-
DPNPC_id(pointer __ptr, const size_type* __shape, const size_type __shape_size)
204+
DPNPC_id(DPCTLSyclQueueRef q_ref, pointer __ptr, const size_type* __shape, const size_type __shape_size)
205205
{
206+
queue_ref = q_ref;
206207
std::vector<size_type> shape(__shape, __shape + __shape_size);
207208
init_container(__ptr, shape);
208209
}
209210

210-
DPNPC_id(pointer __ptr, const size_type* __shape, const size_type* __strides, const size_type __ndim)
211+
DPNPC_id(DPCTLSyclQueueRef q_ref,
212+
pointer __ptr,
213+
const size_type* __shape,
214+
const size_type* __strides,
215+
const size_type __ndim)
211216
{
217+
queue_ref = q_ref;
212218
std::vector<size_type> shape(__shape, __shape + __ndim);
213219
std::vector<size_type> strides(__strides, __strides + __ndim);
214220
init_container(__ptr, shape, strides);
@@ -223,12 +229,14 @@ class DPNPC_id final
223229
*
224230
* @note this function is designed for non-SYCL environment execution
225231
*
232+
* @param [in] q_ref Reference to SYCL queue.
226233
* @param [in] __ptr Pointer to input data. Used to get values only.
227234
* @param [in] __shape Shape of data provided by @ref __ptr.
228235
* Empty container means scalar value pointed by @ref __ptr.
229236
*/
230-
DPNPC_id(pointer __ptr, const std::vector<size_type>& __shape)
237+
DPNPC_id(DPCTLSyclQueueRef q_ref, pointer __ptr, const std::vector<size_type>& __shape)
231238
{
239+
queue_ref = q_ref;
232240
init_container(__ptr, __shape);
233241
}
234242

@@ -296,7 +304,7 @@ class DPNPC_id final
296304

297305
output_shape_size = __shape.size();
298306
const size_type output_shape_size_in_bytes = output_shape_size * sizeof(size_type);
299-
output_shape = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(output_shape_size_in_bytes));
307+
output_shape = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(queue_ref, output_shape_size_in_bytes));
300308

301309
for (int irit = input_shape_size - 1, orit = output_shape_size - 1; orit >= 0; --irit, --orit)
302310
{
@@ -311,13 +319,15 @@ class DPNPC_id final
311319

312320
broadcast_axes_size = valid_axes.size();
313321
const size_type broadcast_axes_size_in_bytes = broadcast_axes_size * sizeof(size_type);
314-
broadcast_axes = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(broadcast_axes_size_in_bytes));
322+
broadcast_axes = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(queue_ref,
323+
broadcast_axes_size_in_bytes));
315324
std::copy(valid_axes.begin(), valid_axes.end(), broadcast_axes);
316325

317326
output_size = std::accumulate(
318327
output_shape, output_shape + output_shape_size, size_type(1), std::multiplies<size_type>());
319328

320-
output_shape_strides = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(output_shape_size_in_bytes));
329+
output_shape_strides = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(queue_ref,
330+
output_shape_size_in_bytes));
321331
get_shape_offsets_inkernel<size_type>(output_shape, output_shape_size, output_shape_strides);
322332

323333
iteration_size = 1;
@@ -392,7 +402,7 @@ class DPNPC_id final
392402
const size_type iteration_shape_size_in_bytes = iteration_shape_size * sizeof(size_type);
393403
std::vector<size_type> iteration_shape;
394404

395-
output_shape = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(output_shape_size_in_bytes));
405+
output_shape = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(queue_ref, output_shape_size_in_bytes));
396406
size_type* output_shape_it = output_shape;
397407
for (size_type i = 0; i < input_shape_size; ++i)
398408
{
@@ -406,7 +416,8 @@ class DPNPC_id final
406416
output_size = std::accumulate(
407417
output_shape, output_shape + output_shape_size, size_type(1), std::multiplies<size_type>());
408418

409-
output_shape_strides = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(output_shape_size_in_bytes));
419+
output_shape_strides = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(queue_ref,
420+
output_shape_size_in_bytes));
410421
get_shape_offsets_inkernel<size_type>(output_shape, output_shape_size, output_shape_strides);
411422

412423
iteration_size = 1;
@@ -418,11 +429,13 @@ class DPNPC_id final
418429
iteration_size *= axis_dim;
419430
}
420431

421-
iteration_shape_strides = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(iteration_shape_size_in_bytes));
432+
iteration_shape_strides = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(queue_ref,
433+
iteration_shape_size_in_bytes));
422434
get_shape_offsets_inkernel<size_type>(
423435
iteration_shape.data(), iteration_shape.size(), iteration_shape_strides);
424436

425-
axes_shape_strides = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(iteration_shape_size_in_bytes));
437+
axes_shape_strides = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(queue_ref,
438+
iteration_shape_size_in_bytes));
426439
for (size_t i = 0; i < iteration_shape_size; ++i)
427440
{
428441
axes_shape_strides[i] = input_shape_strides[axes[i]];
@@ -490,11 +503,12 @@ class DPNPC_id final
490503
}
491504

492505
input_shape_size = __shape.size();
493-
input_shape = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(input_shape_size * sizeof(size_type)));
506+
input_shape = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(queue_ref,
507+
input_shape_size * sizeof(size_type)));
494508
std::copy(__shape.begin(), __shape.end(), input_shape);
495509

496510
input_shape_strides =
497-
reinterpret_cast<size_type*>(dpnp_memory_alloc_c(input_shape_size * sizeof(size_type)));
511+
reinterpret_cast<size_type*>(dpnp_memory_alloc_c(queue_ref, input_shape_size * sizeof(size_type)));
498512
get_shape_offsets_inkernel<size_type>(input_shape, input_shape_size, input_shape_strides);
499513
}
500514
iteration_size = input_size;
@@ -525,11 +539,12 @@ class DPNPC_id final
525539
}
526540

527541
input_shape_size = __shape.size();
528-
input_shape = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(input_shape_size * sizeof(size_type)));
542+
input_shape = reinterpret_cast<size_type*>(dpnp_memory_alloc_c(queue_ref,
543+
input_shape_size * sizeof(size_type)));
529544
std::copy(__shape.begin(), __shape.end(), input_shape);
530545

531546
input_shape_strides =
532-
reinterpret_cast<size_type*>(dpnp_memory_alloc_c(input_shape_size * sizeof(size_type)));
547+
reinterpret_cast<size_type*>(dpnp_memory_alloc_c(queue_ref, input_shape_size * sizeof(size_type)));
533548
std::copy(__strides.begin(), __strides.end(), input_shape_strides);
534549
}
535550
iteration_size = input_size;
@@ -583,23 +598,23 @@ class DPNPC_id final
583598
void free_axes_memory()
584599
{
585600
axes.clear();
586-
dpnp_memory_free_c(axes_shape_strides);
601+
dpnp_memory_free_c(queue_ref, axes_shape_strides);
587602
axes_shape_strides = nullptr;
588603
}
589604

590605
void free_broadcast_axes_memory()
591606
{
592607
broadcast_axes_size = size_type{};
593-
dpnp_memory_free_c(broadcast_axes);
608+
dpnp_memory_free_c(queue_ref, broadcast_axes);
594609
broadcast_axes = nullptr;
595610
}
596611

597612
void free_input_memory()
598613
{
599614
input_size = size_type{};
600615
input_shape_size = size_type{};
601-
dpnp_memory_free_c(input_shape);
602-
dpnp_memory_free_c(input_shape_strides);
616+
dpnp_memory_free_c(queue_ref, input_shape);
617+
dpnp_memory_free_c(queue_ref, input_shape_strides);
603618
input_shape = nullptr;
604619
input_shape_strides = nullptr;
605620
}
@@ -608,16 +623,16 @@ class DPNPC_id final
608623
{
609624
iteration_size = size_type{};
610625
iteration_shape_size = size_type{};
611-
dpnp_memory_free_c(iteration_shape_strides);
626+
dpnp_memory_free_c(queue_ref, iteration_shape_strides);
612627
iteration_shape_strides = nullptr;
613628
}
614629

615630
void free_output_memory()
616631
{
617632
output_size = size_type{};
618633
output_shape_size = size_type{};
619-
dpnp_memory_free_c(output_shape);
620-
dpnp_memory_free_c(output_shape_strides);
634+
dpnp_memory_free_c(queue_ref, output_shape);
635+
dpnp_memory_free_c(queue_ref, output_shape_strides);
621636
output_shape = nullptr;
622637
output_shape_strides = nullptr;
623638
}
@@ -631,6 +646,8 @@ class DPNPC_id final
631646
free_output_memory();
632647
}
633648

649+
DPCTLSyclQueueRef queue_ref = nullptr; /**< reference to SYCL queue */
650+
634651
pointer data = nullptr; /**< input array begin pointer */
635652
size_type input_size = size_type{}; /**< input array size */
636653
size_type* input_shape = nullptr; /**< input array shape */

dpnp/backend/tests/test_broadcast_iterator.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ TEST_P(IteratorBroadcasting, loop_broadcast)
6060
const IteratorParameters& param = GetParam();
6161
std::vector<data_type> input_data = get_input_data<data_type>(param.input_shape);
6262

63-
DPNPC_id<data_type> input(input_data.data(), param.input_shape);
63+
DPCTLSyclQueueRef q_ref = reinterpret_cast<DPCTLSyclQueueRef>(&DPNP_QUEUE);
64+
65+
DPNPC_id<data_type> input(q_ref, input_data.data(), param.input_shape);
6466
input.broadcast_to_shape(param.output_shape);
6567

6668
ASSERT_EQ(input.get_output_size(), param.result.size());
@@ -82,9 +84,11 @@ TEST_P(IteratorBroadcasting, sycl_broadcast)
8284
std::vector<data_type> input_data = get_input_data<data_type>(param.input_shape);
8385
data_type* shared_data = get_shared_data<data_type>(input_data);
8486

87+
DPCTLSyclQueueRef q_ref = reinterpret_cast<DPCTLSyclQueueRef>(&DPNP_QUEUE);
88+
8589
DPNPC_id<data_type>* input_it;
86-
input_it = reinterpret_cast<DPNPC_id<data_type>*>(dpnp_memory_alloc_c(sizeof(DPNPC_id<data_type>)));
87-
new (input_it) DPNPC_id<data_type>(shared_data, param.input_shape);
90+
input_it = reinterpret_cast<DPNPC_id<data_type>*>(dpnp_memory_alloc_c(q_ref, sizeof(DPNPC_id<data_type>)));
91+
new (input_it) DPNPC_id<data_type>(q_ref, shared_data, param.input_shape);
8892

8993
input_it->broadcast_to_shape(param.output_shape);
9094

0 commit comments

Comments
 (0)