Skip to content

Commit 7bc4392

Browse files
committed
Merge branch 'master' into multiply_by_scalar
2 parents cb31bbd + 51938b0 commit 7bc4392

12 files changed

+685
-300
lines changed

dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp

Lines changed: 50 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,8 @@ DPCTLSyclEventRef dpnp_arange_c(DPCTLSyclQueueRef q_ref,
5858
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
5959
sycl::event event;
6060

61+
validate_type_for_device<_DataType>(q);
62+
6163
_DataType* result = reinterpret_cast<_DataType*>(result1);
6264

6365
sycl::range<1> gws(size);
@@ -72,7 +74,6 @@ DPCTLSyclEventRef dpnp_arange_c(DPCTLSyclQueueRef q_ref,
7274
};
7375

7476
event = q.submit(kernel_func);
75-
7677
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
7778

7879
return DPCTLEvent_Copy(event_ref);
@@ -144,6 +145,8 @@ DPCTLSyclEventRef dpnp_diag_c(DPCTLSyclQueueRef q_ref,
144145
DPCTLSyclEventRef event_ref = nullptr;
145146
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
146147

148+
validate_type_for_device<_DataType>(q);
149+
147150
const size_t input1_size = std::accumulate(shape, shape + ndim, 1, std::multiplies<shape_elem_type>());
148151
const size_t result_size = std::accumulate(res_shape, res_shape + res_ndim, 1, std::multiplies<shape_elem_type>());
149152
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref,v_in, input1_size, true);
@@ -194,6 +197,7 @@ void dpnp_diag_c(void* v_in,
194197
res_ndim,
195198
dep_event_vec_ref);
196199
DPCTLEvent_WaitAndThrow(event_ref);
200+
DPCTLEvent_Delete(event_ref);
197201
}
198202

199203
template <typename _DataType>
@@ -240,6 +244,8 @@ DPCTLSyclEventRef dpnp_eye_c(DPCTLSyclQueueRef q_ref,
240244

241245
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
242246

247+
validate_type_for_device<_DataType>(q);
248+
243249
size_t result_size = res_shape[0] * res_shape[1];
244250

245251
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref,result1, result_size, true, true);
@@ -280,6 +286,7 @@ void dpnp_eye_c(void* result1, int k, const shape_elem_type* res_shape)
280286
res_shape,
281287
dep_event_vec_ref);
282288
DPCTLEvent_WaitAndThrow(event_ref);
289+
DPCTLEvent_Delete(event_ref);
283290
}
284291

285292
template <typename _DataType>
@@ -368,7 +375,9 @@ DPCTLSyclEventRef dpnp_identity_c(DPCTLSyclQueueRef q_ref,
368375
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
369376
sycl::event event;
370377

371-
_DataType* result = reinterpret_cast<_DataType*>(result1);
378+
validate_type_for_device<_DataType>(q);
379+
380+
_DataType* result = static_cast<_DataType *>(result1);
372381

373382
sycl::range<2> gws(n, n);
374383
auto kernel_parallel_for_func = [=](sycl::id<2> global_id) {
@@ -382,10 +391,9 @@ DPCTLSyclEventRef dpnp_identity_c(DPCTLSyclQueueRef q_ref,
382391
};
383392

384393
event = q.submit(kernel_func);
394+
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
385395

386-
event.wait();
387-
388-
return event_ref;
396+
return DPCTLEvent_Copy(event_ref);
389397
}
390398

391399
template <typename _DataType>
@@ -398,6 +406,7 @@ void dpnp_identity_c(void* result1, const size_t n)
398406
n,
399407
dep_event_vec_ref);
400408
DPCTLEvent_WaitAndThrow(event_ref);
409+
DPCTLEvent_Delete(event_ref);
401410
}
402411

403412
template <typename _DataType>
@@ -425,10 +434,11 @@ DPCTLSyclEventRef dpnp_ones_c(DPCTLSyclQueueRef q_ref,
425434

426435
DPCTLSyclEventRef event_ref = dpnp_initval_c<_DataType>(q_ref, result, fill_value, size, dep_event_vec_ref);
427436
DPCTLEvent_WaitAndThrow(event_ref);
437+
DPCTLEvent_Delete(event_ref);
428438

429439
sycl::free(fill_value, q);
430440

431-
return event_ref;
441+
return nullptr;
432442
}
433443

434444
template <typename _DataType>
@@ -441,6 +451,7 @@ void dpnp_ones_c(void* result, size_t size)
441451
size,
442452
dep_event_vec_ref);
443453
DPCTLEvent_WaitAndThrow(event_ref);
454+
DPCTLEvent_Delete(event_ref);
444455
}
445456

446457
template <typename _DataType>
@@ -471,6 +482,7 @@ void dpnp_ones_like_c(void* result, size_t size)
471482
size,
472483
dep_event_vec_ref);
473484
DPCTLEvent_WaitAndThrow(event_ref);
485+
DPCTLEvent_Delete(event_ref);
474486
}
475487

476488
template <typename _DataType>
@@ -520,6 +532,8 @@ DPCTLSyclEventRef dpnp_ptp_c(DPCTLSyclQueueRef q_ref,
520532

521533
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
522534

535+
validate_type_for_device<_DataType>(q);
536+
523537
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref,input1_in, input_size, true);
524538
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref,result1_out, result_size, false, true);
525539
_DataType* arr = input1_ptr.get_ptr();
@@ -563,7 +577,7 @@ DPCTLSyclEventRef dpnp_ptp_c(DPCTLSyclQueueRef q_ref,
563577
sycl::free(max_arr, q);
564578
sycl::free(_strides, q);
565579

566-
return event_ref;
580+
return DPCTLEvent_Copy(event_ref);
567581
}
568582

569583
template <typename _DataType>
@@ -649,6 +663,9 @@ DPCTLSyclEventRef dpnp_vander_c(DPCTLSyclQueueRef q_ref,
649663

650664
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
651665

666+
validate_type_for_device<_DataType_input>(q);
667+
validate_type_for_device<_DataType_output>(q);
668+
652669
DPNPC_ptr_adapter<_DataType_input> input1_ptr(q_ref,array1_in, size_in, true);
653670
DPNPC_ptr_adapter<_DataType_output> result_ptr(q_ref,result1, size_in * N, true, true);
654671
const _DataType_input* array_in = input1_ptr.get_ptr();
@@ -693,7 +710,7 @@ DPCTLSyclEventRef dpnp_vander_c(DPCTLSyclQueueRef q_ref,
693710
}
694711
}
695712

696-
return event_ref;
713+
return DPCTLEvent_Copy(event_ref);
697714
}
698715

699716
template <typename _DataType_input, typename _DataType_output>
@@ -709,6 +726,7 @@ void dpnp_vander_c(const void* array1_in, void* result1, const size_t size_in, c
709726
increasing,
710727
dep_event_vec_ref);
711728
DPCTLEvent_WaitAndThrow(event_ref);
729+
DPCTLEvent_Delete(event_ref);
712730
}
713731

714732
template <typename _DataType_input, typename _DataType_output>
@@ -757,10 +775,11 @@ DPCTLSyclEventRef dpnp_trace_c(DPCTLSyclQueueRef q_ref,
757775

758776
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
759777

760-
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref,array1_in, size * last_dim);
778+
validate_type_for_device<_DataType>(q);
779+
validate_type_for_device<_ResultType>(q);
761780

762-
const _DataType* input = input1_ptr.get_ptr();
763-
_ResultType* result = reinterpret_cast<_ResultType*>(result_in);
781+
const _DataType* input = static_cast<const _DataType *>(array1_in);
782+
_ResultType* result = static_cast<_ResultType *>(result_in);
764783

765784
sycl::range<1> gws(size);
766785
auto kernel_parallel_for_func = [=](auto index) {
@@ -780,7 +799,6 @@ DPCTLSyclEventRef dpnp_trace_c(DPCTLSyclQueueRef q_ref,
780799
};
781800

782801
auto event = q.submit(kernel_func);
783-
784802
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
785803

786804
return DPCTLEvent_Copy(event_ref);
@@ -798,6 +816,7 @@ void dpnp_trace_c(const void* array1_in, void* result_in, const shape_elem_type*
798816
ndim,
799817
dep_event_vec_ref);
800818
DPCTLEvent_WaitAndThrow(event_ref);
819+
DPCTLEvent_Delete(event_ref);
801820
}
802821

803822
template <typename _DataType, typename _ResultType>
@@ -839,7 +858,9 @@ DPCTLSyclEventRef dpnp_tri_c(DPCTLSyclQueueRef q_ref,
839858

840859
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
841860

842-
_DataType* result = reinterpret_cast<_DataType*>(result1);
861+
validate_type_for_device<_DataType>(q);
862+
863+
_DataType* result = static_cast<_DataType *>(result1);
843864

844865
size_t idx = N * M;
845866
sycl::range<1> gws(idx);
@@ -867,7 +888,6 @@ DPCTLSyclEventRef dpnp_tri_c(DPCTLSyclQueueRef q_ref,
867888
};
868889

869890
event = q.submit(kernel_func);
870-
871891
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
872892

873893
return DPCTLEvent_Copy(event_ref);
@@ -885,6 +905,7 @@ void dpnp_tri_c(void* result1, const size_t N, const size_t M, const int k)
885905
k,
886906
dep_event_vec_ref);
887907
DPCTLEvent_WaitAndThrow(event_ref);
908+
DPCTLEvent_Delete(event_ref);
888909
}
889910

890911
template <typename _DataType>
@@ -946,6 +967,8 @@ DPCTLSyclEventRef dpnp_tril_c(DPCTLSyclQueueRef q_ref,
946967

947968
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
948969

970+
validate_type_for_device<_DataType>(q);
971+
949972
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref,array_in, input_size, true);
950973
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref,result1, res_size, true, true);
951974
_DataType* array_m = input1_ptr.get_ptr();
@@ -1015,7 +1038,7 @@ DPCTLSyclEventRef dpnp_tril_c(DPCTLSyclQueueRef q_ref,
10151038
}
10161039
}
10171040
}
1018-
return event_ref;
1041+
return DPCTLEvent_Copy(event_ref);
10191042
}
10201043

10211044
template <typename _DataType>
@@ -1039,6 +1062,7 @@ void dpnp_tril_c(void* array_in,
10391062
res_ndim,
10401063
dep_event_vec_ref);
10411064
DPCTLEvent_WaitAndThrow(event_ref);
1065+
DPCTLEvent_Delete(event_ref);
10421066
}
10431067

10441068
template <typename _DataType>
@@ -1106,6 +1130,8 @@ DPCTLSyclEventRef dpnp_triu_c(DPCTLSyclQueueRef q_ref,
11061130

11071131
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
11081132

1133+
validate_type_for_device<_DataType>(q);
1134+
11091135
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref,array_in, input_size, true);
11101136
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref,result1, res_size, true, true);
11111137
_DataType* array_m = input1_ptr.get_ptr();
@@ -1175,7 +1201,7 @@ DPCTLSyclEventRef dpnp_triu_c(DPCTLSyclQueueRef q_ref,
11751201
}
11761202
}
11771203
}
1178-
return event_ref;
1204+
return DPCTLEvent_Copy(event_ref);
11791205
}
11801206

11811207
template <typename _DataType>
@@ -1199,6 +1225,7 @@ void dpnp_triu_c(void* array_in,
11991225
res_ndim,
12001226
dep_event_vec_ref);
12011227
DPCTLEvent_WaitAndThrow(event_ref);
1228+
DPCTLEvent_Delete(event_ref);
12021229
}
12031230

12041231
template <typename _DataType>
@@ -1234,10 +1261,11 @@ DPCTLSyclEventRef dpnp_zeros_c(DPCTLSyclQueueRef q_ref,
12341261

12351262
DPCTLSyclEventRef event_ref = dpnp_initval_c<_DataType>(q_ref, result, fill_value, size, dep_event_vec_ref);
12361263
DPCTLEvent_WaitAndThrow(event_ref);
1264+
DPCTLEvent_Delete(event_ref);
12371265

12381266
sycl::free(fill_value, q);
12391267

1240-
return event_ref;
1268+
return nullptr;
12411269
}
12421270

12431271
template <typename _DataType>
@@ -1337,6 +1365,8 @@ void func_map_init_arraycreation(func_map_t& fmap)
13371365
fmap[DPNPFuncName::DPNP_FN_IDENTITY_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_identity_ext_c<float>};
13381366
fmap[DPNPFuncName::DPNP_FN_IDENTITY_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_identity_ext_c<double>};
13391367
fmap[DPNPFuncName::DPNP_FN_IDENTITY_EXT][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_identity_ext_c<bool>};
1368+
fmap[DPNPFuncName::DPNP_FN_IDENTITY_EXT][eft_C64][eft_C64] = {eft_C64,
1369+
(void*)dpnp_identity_ext_c<std::complex<float>>};
13401370
fmap[DPNPFuncName::DPNP_FN_IDENTITY_EXT][eft_C128][eft_C128] = {eft_C128,
13411371
(void*)dpnp_identity_ext_c<std::complex<double>>};
13421372

@@ -1392,9 +1422,11 @@ void func_map_init_arraycreation(func_map_t& fmap)
13921422

13931423
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_INT][eft_INT] = {eft_LNG, (void*)dpnp_vander_ext_c<int32_t, int64_t>};
13941424
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_vander_ext_c<int64_t, int64_t>};
1395-
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_FLT][eft_FLT] = {eft_DBL, (void*)dpnp_vander_ext_c<float, double>};
1425+
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_vander_ext_c<float, float>};
13961426
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_vander_ext_c<double, double>};
13971427
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_BLN][eft_BLN] = {eft_LNG, (void*)dpnp_vander_ext_c<bool, int64_t>};
1428+
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_C64][eft_C64] = {
1429+
eft_C64, (void*)dpnp_vander_ext_c<std::complex<float>, std::complex<float>>};
13981430
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_C128][eft_C128] = {
13991431
eft_C128, (void*)dpnp_vander_ext_c<std::complex<double>, std::complex<double>>};
14001432

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ class dpnp_initval_c_kernel;
728728

729729
template <typename _DataType>
730730
DPCTLSyclEventRef dpnp_initval_c(DPCTLSyclQueueRef q_ref,
731-
void* result1,
731+
void* result,
732732
void* value,
733733
size_t size,
734734
const DPCTLEventVectorRef dep_event_vec_ref)
@@ -744,24 +744,11 @@ DPCTLSyclEventRef dpnp_initval_c(DPCTLSyclQueueRef q_ref,
744744
}
745745

746746
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
747+
_DataType val = *(static_cast<_DataType *>(value));
747748

748-
DPNPC_ptr_adapter<_DataType> result1_ptr(q_ref, result1, size);
749-
DPNPC_ptr_adapter<_DataType> value_ptr(q_ref, value, 1);
750-
_DataType* result = result1_ptr.get_ptr();
751-
_DataType* val = value_ptr.get_ptr();
752-
753-
sycl::range<1> gws(size);
754-
auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
755-
const size_t idx = global_id[0];
756-
result[idx] = *val;
757-
};
758-
759-
auto kernel_func = [&](sycl::handler& cgh) {
760-
cgh.parallel_for<class dpnp_initval_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
761-
};
762-
763-
sycl::event event = q.submit(kernel_func);
749+
validate_type_for_device<_DataType>(q);
764750

751+
auto event = q.fill<_DataType>(result, val, size);
765752
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
766753

767754
return DPCTLEvent_Copy(event_ref);
@@ -1149,6 +1136,8 @@ void func_map_init_linalg(func_map_t& fmap)
11491136
fmap[DPNPFuncName::DPNP_FN_INITVAL_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_initval_ext_c<int64_t>};
11501137
fmap[DPNPFuncName::DPNP_FN_INITVAL_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_initval_ext_c<float>};
11511138
fmap[DPNPFuncName::DPNP_FN_INITVAL_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_initval_ext_c<double>};
1139+
fmap[DPNPFuncName::DPNP_FN_INITVAL_EXT][eft_C64][eft_C64] = {eft_C64,
1140+
(void*)dpnp_initval_ext_c<std::complex<float>>};
11521141
fmap[DPNPFuncName::DPNP_FN_INITVAL_EXT][eft_C128][eft_C128] = {eft_C128,
11531142
(void*)dpnp_initval_ext_c<std::complex<double>>};
11541143

0 commit comments

Comments
 (0)