@@ -58,6 +58,8 @@ DPCTLSyclEventRef dpnp_arange_c(DPCTLSyclQueueRef q_ref,
58
58
sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
59
59
sycl::event event;
60
60
61
+ validate_type_for_device<_DataType>(q);
62
+
61
63
_DataType* result = reinterpret_cast <_DataType*>(result1);
62
64
63
65
sycl::range<1 > gws (size);
@@ -72,7 +74,6 @@ DPCTLSyclEventRef dpnp_arange_c(DPCTLSyclQueueRef q_ref,
72
74
};
73
75
74
76
event = q.submit (kernel_func);
75
-
76
77
event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
77
78
78
79
return DPCTLEvent_Copy (event_ref);
@@ -144,6 +145,8 @@ DPCTLSyclEventRef dpnp_diag_c(DPCTLSyclQueueRef q_ref,
144
145
DPCTLSyclEventRef event_ref = nullptr ;
145
146
sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
146
147
148
+ validate_type_for_device<_DataType>(q);
149
+
147
150
const size_t input1_size = std::accumulate (shape, shape + ndim, 1 , std::multiplies<shape_elem_type>());
148
151
const size_t result_size = std::accumulate (res_shape, res_shape + res_ndim, 1 , std::multiplies<shape_elem_type>());
149
152
DPNPC_ptr_adapter<_DataType> input1_ptr (q_ref,v_in, input1_size, true );
@@ -194,6 +197,7 @@ void dpnp_diag_c(void* v_in,
194
197
res_ndim,
195
198
dep_event_vec_ref);
196
199
DPCTLEvent_WaitAndThrow (event_ref);
200
+ DPCTLEvent_Delete (event_ref);
197
201
}
198
202
199
203
template <typename _DataType>
@@ -240,6 +244,8 @@ DPCTLSyclEventRef dpnp_eye_c(DPCTLSyclQueueRef q_ref,
240
244
241
245
sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
242
246
247
+ validate_type_for_device<_DataType>(q);
248
+
243
249
size_t result_size = res_shape[0 ] * res_shape[1 ];
244
250
245
251
DPNPC_ptr_adapter<_DataType> result_ptr (q_ref,result1, result_size, true , true );
@@ -280,6 +286,7 @@ void dpnp_eye_c(void* result1, int k, const shape_elem_type* res_shape)
280
286
res_shape,
281
287
dep_event_vec_ref);
282
288
DPCTLEvent_WaitAndThrow (event_ref);
289
+ DPCTLEvent_Delete (event_ref);
283
290
}
284
291
285
292
template <typename _DataType>
@@ -368,7 +375,9 @@ DPCTLSyclEventRef dpnp_identity_c(DPCTLSyclQueueRef q_ref,
368
375
sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
369
376
sycl::event event;
370
377
371
- _DataType* result = reinterpret_cast <_DataType*>(result1);
378
+ validate_type_for_device<_DataType>(q);
379
+
380
+ _DataType* result = static_cast <_DataType *>(result1);
372
381
373
382
sycl::range<2 > gws (n, n);
374
383
auto kernel_parallel_for_func = [=](sycl::id<2 > global_id) {
@@ -382,10 +391,9 @@ DPCTLSyclEventRef dpnp_identity_c(DPCTLSyclQueueRef q_ref,
382
391
};
383
392
384
393
event = q.submit (kernel_func);
394
+ event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
385
395
386
- event.wait ();
387
-
388
- return event_ref;
396
+ return DPCTLEvent_Copy (event_ref);
389
397
}
390
398
391
399
template <typename _DataType>
@@ -398,6 +406,7 @@ void dpnp_identity_c(void* result1, const size_t n)
398
406
n,
399
407
dep_event_vec_ref);
400
408
DPCTLEvent_WaitAndThrow (event_ref);
409
+ DPCTLEvent_Delete (event_ref);
401
410
}
402
411
403
412
template <typename _DataType>
@@ -425,10 +434,11 @@ DPCTLSyclEventRef dpnp_ones_c(DPCTLSyclQueueRef q_ref,
425
434
426
435
DPCTLSyclEventRef event_ref = dpnp_initval_c<_DataType>(q_ref, result, fill_value, size, dep_event_vec_ref);
427
436
DPCTLEvent_WaitAndThrow (event_ref);
437
+ DPCTLEvent_Delete (event_ref);
428
438
429
439
sycl::free (fill_value, q);
430
440
431
- return event_ref ;
441
+ return nullptr ;
432
442
}
433
443
434
444
template <typename _DataType>
@@ -441,6 +451,7 @@ void dpnp_ones_c(void* result, size_t size)
441
451
size,
442
452
dep_event_vec_ref);
443
453
DPCTLEvent_WaitAndThrow (event_ref);
454
+ DPCTLEvent_Delete (event_ref);
444
455
}
445
456
446
457
template <typename _DataType>
@@ -471,6 +482,7 @@ void dpnp_ones_like_c(void* result, size_t size)
471
482
size,
472
483
dep_event_vec_ref);
473
484
DPCTLEvent_WaitAndThrow (event_ref);
485
+ DPCTLEvent_Delete (event_ref);
474
486
}
475
487
476
488
template <typename _DataType>
@@ -520,6 +532,8 @@ DPCTLSyclEventRef dpnp_ptp_c(DPCTLSyclQueueRef q_ref,
520
532
521
533
sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
522
534
535
+ validate_type_for_device<_DataType>(q);
536
+
523
537
DPNPC_ptr_adapter<_DataType> input1_ptr (q_ref,input1_in, input_size, true );
524
538
DPNPC_ptr_adapter<_DataType> result_ptr (q_ref,result1_out, result_size, false , true );
525
539
_DataType* arr = input1_ptr.get_ptr ();
@@ -563,7 +577,7 @@ DPCTLSyclEventRef dpnp_ptp_c(DPCTLSyclQueueRef q_ref,
563
577
sycl::free (max_arr, q);
564
578
sycl::free (_strides, q);
565
579
566
- return event_ref;
580
+ return DPCTLEvent_Copy ( event_ref) ;
567
581
}
568
582
569
583
template <typename _DataType>
@@ -649,6 +663,9 @@ DPCTLSyclEventRef dpnp_vander_c(DPCTLSyclQueueRef q_ref,
649
663
650
664
sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
651
665
666
+ validate_type_for_device<_DataType_input>(q);
667
+ validate_type_for_device<_DataType_output>(q);
668
+
652
669
DPNPC_ptr_adapter<_DataType_input> input1_ptr (q_ref,array1_in, size_in, true );
653
670
DPNPC_ptr_adapter<_DataType_output> result_ptr (q_ref,result1, size_in * N, true , true );
654
671
const _DataType_input* array_in = input1_ptr.get_ptr ();
@@ -693,7 +710,7 @@ DPCTLSyclEventRef dpnp_vander_c(DPCTLSyclQueueRef q_ref,
693
710
}
694
711
}
695
712
696
- return event_ref;
713
+ return DPCTLEvent_Copy ( event_ref) ;
697
714
}
698
715
699
716
template <typename _DataType_input, typename _DataType_output>
@@ -709,6 +726,7 @@ void dpnp_vander_c(const void* array1_in, void* result1, const size_t size_in, c
709
726
increasing,
710
727
dep_event_vec_ref);
711
728
DPCTLEvent_WaitAndThrow (event_ref);
729
+ DPCTLEvent_Delete (event_ref);
712
730
}
713
731
714
732
template <typename _DataType_input, typename _DataType_output>
@@ -757,10 +775,11 @@ DPCTLSyclEventRef dpnp_trace_c(DPCTLSyclQueueRef q_ref,
757
775
758
776
sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
759
777
760
- DPNPC_ptr_adapter<_DataType> input1_ptr (q_ref,array1_in, size * last_dim);
778
+ validate_type_for_device<_DataType>(q);
779
+ validate_type_for_device<_ResultType>(q);
761
780
762
- const _DataType* input = input1_ptr. get_ptr ( );
763
- _ResultType* result = reinterpret_cast <_ResultType*>(result_in);
781
+ const _DataType* input = static_cast < const _DataType *>(array1_in );
782
+ _ResultType* result = static_cast <_ResultType *>(result_in);
764
783
765
784
sycl::range<1 > gws (size);
766
785
auto kernel_parallel_for_func = [=](auto index) {
@@ -780,7 +799,6 @@ DPCTLSyclEventRef dpnp_trace_c(DPCTLSyclQueueRef q_ref,
780
799
};
781
800
782
801
auto event = q.submit (kernel_func);
783
-
784
802
event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
785
803
786
804
return DPCTLEvent_Copy (event_ref);
@@ -798,6 +816,7 @@ void dpnp_trace_c(const void* array1_in, void* result_in, const shape_elem_type*
798
816
ndim,
799
817
dep_event_vec_ref);
800
818
DPCTLEvent_WaitAndThrow (event_ref);
819
+ DPCTLEvent_Delete (event_ref);
801
820
}
802
821
803
822
template <typename _DataType, typename _ResultType>
@@ -839,7 +858,9 @@ DPCTLSyclEventRef dpnp_tri_c(DPCTLSyclQueueRef q_ref,
839
858
840
859
sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
841
860
842
- _DataType* result = reinterpret_cast <_DataType*>(result1);
861
+ validate_type_for_device<_DataType>(q);
862
+
863
+ _DataType* result = static_cast <_DataType *>(result1);
843
864
844
865
size_t idx = N * M;
845
866
sycl::range<1 > gws (idx);
@@ -867,7 +888,6 @@ DPCTLSyclEventRef dpnp_tri_c(DPCTLSyclQueueRef q_ref,
867
888
};
868
889
869
890
event = q.submit (kernel_func);
870
-
871
891
event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
872
892
873
893
return DPCTLEvent_Copy (event_ref);
@@ -885,6 +905,7 @@ void dpnp_tri_c(void* result1, const size_t N, const size_t M, const int k)
885
905
k,
886
906
dep_event_vec_ref);
887
907
DPCTLEvent_WaitAndThrow (event_ref);
908
+ DPCTLEvent_Delete (event_ref);
888
909
}
889
910
890
911
template <typename _DataType>
@@ -946,6 +967,8 @@ DPCTLSyclEventRef dpnp_tril_c(DPCTLSyclQueueRef q_ref,
946
967
947
968
sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
948
969
970
+ validate_type_for_device<_DataType>(q);
971
+
949
972
DPNPC_ptr_adapter<_DataType> input1_ptr (q_ref,array_in, input_size, true );
950
973
DPNPC_ptr_adapter<_DataType> result_ptr (q_ref,result1, res_size, true , true );
951
974
_DataType* array_m = input1_ptr.get_ptr ();
@@ -1015,7 +1038,7 @@ DPCTLSyclEventRef dpnp_tril_c(DPCTLSyclQueueRef q_ref,
1015
1038
}
1016
1039
}
1017
1040
}
1018
- return event_ref;
1041
+ return DPCTLEvent_Copy ( event_ref) ;
1019
1042
}
1020
1043
1021
1044
template <typename _DataType>
@@ -1039,6 +1062,7 @@ void dpnp_tril_c(void* array_in,
1039
1062
res_ndim,
1040
1063
dep_event_vec_ref);
1041
1064
DPCTLEvent_WaitAndThrow (event_ref);
1065
+ DPCTLEvent_Delete (event_ref);
1042
1066
}
1043
1067
1044
1068
template <typename _DataType>
@@ -1106,6 +1130,8 @@ DPCTLSyclEventRef dpnp_triu_c(DPCTLSyclQueueRef q_ref,
1106
1130
1107
1131
sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
1108
1132
1133
+ validate_type_for_device<_DataType>(q);
1134
+
1109
1135
DPNPC_ptr_adapter<_DataType> input1_ptr (q_ref,array_in, input_size, true );
1110
1136
DPNPC_ptr_adapter<_DataType> result_ptr (q_ref,result1, res_size, true , true );
1111
1137
_DataType* array_m = input1_ptr.get_ptr ();
@@ -1175,7 +1201,7 @@ DPCTLSyclEventRef dpnp_triu_c(DPCTLSyclQueueRef q_ref,
1175
1201
}
1176
1202
}
1177
1203
}
1178
- return event_ref;
1204
+ return DPCTLEvent_Copy ( event_ref) ;
1179
1205
}
1180
1206
1181
1207
template <typename _DataType>
@@ -1199,6 +1225,7 @@ void dpnp_triu_c(void* array_in,
1199
1225
res_ndim,
1200
1226
dep_event_vec_ref);
1201
1227
DPCTLEvent_WaitAndThrow (event_ref);
1228
+ DPCTLEvent_Delete (event_ref);
1202
1229
}
1203
1230
1204
1231
template <typename _DataType>
@@ -1234,10 +1261,11 @@ DPCTLSyclEventRef dpnp_zeros_c(DPCTLSyclQueueRef q_ref,
1234
1261
1235
1262
DPCTLSyclEventRef event_ref = dpnp_initval_c<_DataType>(q_ref, result, fill_value, size, dep_event_vec_ref);
1236
1263
DPCTLEvent_WaitAndThrow (event_ref);
1264
+ DPCTLEvent_Delete (event_ref);
1237
1265
1238
1266
sycl::free (fill_value, q);
1239
1267
1240
- return event_ref ;
1268
+ return nullptr ;
1241
1269
}
1242
1270
1243
1271
template <typename _DataType>
@@ -1337,6 +1365,8 @@ void func_map_init_arraycreation(func_map_t& fmap)
1337
1365
fmap[DPNPFuncName::DPNP_FN_IDENTITY_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void *)dpnp_identity_ext_c<float >};
1338
1366
fmap[DPNPFuncName::DPNP_FN_IDENTITY_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void *)dpnp_identity_ext_c<double >};
1339
1367
fmap[DPNPFuncName::DPNP_FN_IDENTITY_EXT][eft_BLN][eft_BLN] = {eft_BLN, (void *)dpnp_identity_ext_c<bool >};
1368
+ fmap[DPNPFuncName::DPNP_FN_IDENTITY_EXT][eft_C64][eft_C64] = {eft_C64,
1369
+ (void *)dpnp_identity_ext_c<std::complex<float >>};
1340
1370
fmap[DPNPFuncName::DPNP_FN_IDENTITY_EXT][eft_C128][eft_C128] = {eft_C128,
1341
1371
(void *)dpnp_identity_ext_c<std::complex<double >>};
1342
1372
@@ -1392,9 +1422,11 @@ void func_map_init_arraycreation(func_map_t& fmap)
1392
1422
1393
1423
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_INT][eft_INT] = {eft_LNG, (void *)dpnp_vander_ext_c<int32_t , int64_t >};
1394
1424
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_vander_ext_c<int64_t , int64_t >};
1395
- fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_FLT][eft_FLT] = {eft_DBL , (void *)dpnp_vander_ext_c<float , double >};
1425
+ fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_FLT][eft_FLT] = {eft_FLT , (void *)dpnp_vander_ext_c<float , float >};
1396
1426
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void *)dpnp_vander_ext_c<double , double >};
1397
1427
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_BLN][eft_BLN] = {eft_LNG, (void *)dpnp_vander_ext_c<bool , int64_t >};
1428
+ fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_C64][eft_C64] = {
1429
+ eft_C64, (void *)dpnp_vander_ext_c<std::complex<float >, std::complex<float >>};
1398
1430
fmap[DPNPFuncName::DPNP_FN_VANDER_EXT][eft_C128][eft_C128] = {
1399
1431
eft_C128, (void *)dpnp_vander_ext_c<std::complex<double >, std::complex<double >>};
1400
1432
0 commit comments