42
42
43
43
namespace py = pybind11;
44
44
45
- static dpctl::tensor::detail::usm_ndarray_types array_types;
46
-
47
45
namespace
48
46
{
49
47
@@ -301,6 +299,7 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
301
299
int src_typenum = src.get_typenum ();
302
300
int dst_typenum = dst.get_typenum ();
303
301
302
+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
304
303
int src_type_id = array_types.typenum_to_lookup_id (src_typenum);
305
304
int dst_type_id = array_types.typenum_to_lookup_id (dst_typenum);
306
305
@@ -322,15 +321,16 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
322
321
throw py::value_error (" Arrays index overlapping segments of memory" );
323
322
}
324
323
325
- int src_flags = src.get_flags ();
326
- int dst_flags = dst.get_flags ();
324
+ bool is_src_c_contig = src.is_c_contiguous ();
325
+ bool is_src_f_contig = src.is_f_contiguous ();
326
+
327
+ bool is_dst_c_contig = dst.is_c_contiguous ();
328
+ bool is_dst_f_contig = dst.is_f_contiguous ();
327
329
328
330
// check for applicability of special cases:
329
331
// (same type && (both C-contiguous || both F-contiguous)
330
- bool both_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) &&
331
- (dst_flags & USM_ARRAY_C_CONTIGUOUS));
332
- bool both_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) &&
333
- (dst_flags & USM_ARRAY_F_CONTIGUOUS));
332
+ bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
333
+ bool both_f_contig = (is_src_f_contig && is_dst_f_contig);
334
334
if (both_c_contig || both_f_contig) {
335
335
if (src_type_id == dst_type_id) {
336
336
@@ -360,12 +360,6 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
360
360
int nd = src_nd;
361
361
const py::ssize_t *shape = src_shape;
362
362
363
- bool is_src_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) != 0 );
364
- bool is_src_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) != 0 );
365
-
366
- bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0 );
367
- bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0 );
368
-
369
363
constexpr py::ssize_t src_itemsize = 1 ; // in elements
370
364
constexpr py::ssize_t dst_itemsize = 1 ; // in elements
371
365
@@ -550,6 +544,7 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
550
544
const py::ssize_t *src_shape = src.get_shape_raw ();
551
545
const py::ssize_t *dst_shape = dst.get_shape_raw ();
552
546
547
+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
553
548
int type_id = array_types.typenum_to_lookup_id (src_typenum);
554
549
555
550
auto fn = copy_for_reshape_generic_dispatch_vector[type_id];
@@ -576,14 +571,13 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
576
571
577
572
const py::ssize_t *src_strides = src.get_strides_raw ();
578
573
if (src_strides == nullptr ) {
579
- int src_flags = src.get_flags ();
580
- if (src_flags & USM_ARRAY_C_CONTIGUOUS) {
574
+ if (src.is_c_contiguous ()) {
581
575
const auto &src_contig_strides =
582
576
c_contiguous_strides (src_nd, src_shape);
583
577
std::copy (src_contig_strides.begin (), src_contig_strides.end (),
584
578
packed_host_shapes_strides_shp->begin () + src_nd);
585
579
}
586
- else if (src_flags & USM_ARRAY_F_CONTIGUOUS ) {
580
+ else if (src. is_f_contiguous () ) {
587
581
const auto &src_contig_strides =
588
582
f_contiguous_strides (src_nd, src_shape);
589
583
std::copy (src_contig_strides.begin (), src_contig_strides.end (),
@@ -602,15 +596,14 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
602
596
603
597
const py::ssize_t *dst_strides = dst.get_strides_raw ();
604
598
if (dst_strides == nullptr ) {
605
- int dst_flags = dst.get_flags ();
606
- if (dst_flags & USM_ARRAY_C_CONTIGUOUS) {
599
+ if (dst.is_c_contiguous ()) {
607
600
const auto &dst_contig_strides =
608
601
c_contiguous_strides (dst_nd, dst_shape);
609
602
std::copy (dst_contig_strides.begin (), dst_contig_strides.end (),
610
603
packed_host_shapes_strides_shp->begin () + 2 * src_nd +
611
604
dst_nd);
612
605
}
613
- else if (dst_flags & USM_ARRAY_F_CONTIGUOUS ) {
606
+ else if (dst. is_f_contiguous () ) {
614
607
const auto &dst_contig_strides =
615
608
f_contiguous_strides (dst_nd, dst_shape);
616
609
std::copy (dst_contig_strides.begin (), dst_contig_strides.end (),
@@ -736,6 +729,7 @@ void copy_numpy_ndarray_into_usm_ndarray(
736
729
py::detail::array_descriptor_proxy (npy_src.dtype ().ptr ())->type_num ;
737
730
int dst_typenum = dst.get_typenum ();
738
731
732
+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
739
733
int src_type_id = array_types.typenum_to_lookup_id (src_typenum);
740
734
int dst_type_id = array_types.typenum_to_lookup_id (dst_typenum);
741
735
@@ -744,14 +738,13 @@ void copy_numpy_ndarray_into_usm_ndarray(
744
738
char *dst_data = dst.get_data ();
745
739
746
740
int src_flags = npy_src.flags ();
747
- int dst_flags = dst.get_flags ();
748
741
749
742
// check for applicability of special cases:
750
743
// (same type && (both C-contiguous || both F-contiguous)
751
- bool both_c_contig = ((src_flags & py::array::c_style) &&
752
- (dst_flags & USM_ARRAY_C_CONTIGUOUS ));
753
- bool both_f_contig = ((src_flags & py::array::f_style) &&
754
- (dst_flags & USM_ARRAY_F_CONTIGUOUS ));
744
+ bool both_c_contig =
745
+ ((src_flags & py::array::c_style) && dst. is_c_contiguous ( ));
746
+ bool both_f_contig =
747
+ ((src_flags & py::array::f_style) && dst. is_f_contiguous ( ));
755
748
if (both_c_contig || both_f_contig) {
756
749
if (src_type_id == dst_type_id) {
757
750
int src_elem_size = npy_src.itemsize ();
@@ -791,8 +784,8 @@ void copy_numpy_ndarray_into_usm_ndarray(
791
784
bool is_src_c_contig = ((src_flags & py::array::c_style) != 0 );
792
785
bool is_src_f_contig = ((src_flags & py::array::f_style) != 0 );
793
786
794
- bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0 );
795
- bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0 );
787
+ bool is_dst_c_contig = dst. is_c_contiguous ( );
788
+ bool is_dst_f_contig = dst. is_f_contiguous ( );
796
789
797
790
// all args except itemsizes and is_?_contig bools can be modified by
798
791
// reference
@@ -906,18 +899,18 @@ usm_ndarray_linear_sequence_step(py::object start,
906
899
" usm_ndarray_linspace: Expecting 1D array to populate" );
907
900
}
908
901
909
- int flags = dst.get_flags ();
910
- if (!(flags & USM_ARRAY_C_CONTIGUOUS)) {
902
+ if (!dst.is_c_contiguous ()) {
911
903
throw py::value_error (
912
904
" usm_ndarray_linspace: Non-contiguous arrays are not supported" );
913
905
}
914
906
915
907
sycl::queue dst_q = dst.get_queue ();
916
- if (dst_q != exec_q && dst_q. get_context () != exec_q. get_context ( )) {
908
+ if (! dpctl::utils::queues_are_compatible ( exec_q, { dst_q} )) {
917
909
throw py::value_error (
918
- " Execution queue context is not the same as allocation context " );
910
+ " Execution queue is not compatible with the allocation queue " );
919
911
}
920
912
913
+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
921
914
int dst_typenum = dst.get_typenum ();
922
915
int dst_typeid = array_types.typenum_to_lookup_id (dst_typenum);
923
916
@@ -955,18 +948,18 @@ usm_ndarray_linear_sequence_affine(py::object start,
955
948
" usm_ndarray_linspace: Expecting 1D array to populate" );
956
949
}
957
950
958
- int flags = dst.get_flags ();
959
- if (!(flags & USM_ARRAY_C_CONTIGUOUS)) {
951
+ if (!dst.is_c_contiguous ()) {
960
952
throw py::value_error (
961
953
" usm_ndarray_linspace: Non-contiguous arrays are not supported" );
962
954
}
963
955
964
956
sycl::queue dst_q = dst.get_queue ();
965
- if (dst_q != exec_q && dst_q. get_context () != exec_q. get_context ( )) {
957
+ if (! dpctl::utils::queues_are_compatible ( exec_q, { dst_q} )) {
966
958
throw py::value_error (
967
959
" Execution queue context is not the same as allocation context" );
968
960
}
969
961
962
+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
970
963
int dst_typenum = dst.get_typenum ();
971
964
int dst_typeid = array_types.typenum_to_lookup_id (dst_typenum);
972
965
@@ -1010,23 +1003,20 @@ usm_ndarray_full(py::object py_value,
1010
1003
return std::make_pair (sycl::event (), sycl::event ());
1011
1004
}
1012
1005
1013
- int dst_flags = dst.get_flags ();
1014
-
1015
1006
sycl::queue dst_q = dst.get_queue ();
1016
- if (dst_q != exec_q && dst_q. get_context () != exec_q. get_context ( )) {
1007
+ if (! dpctl::utils::queues_are_compatible ( exec_q, { dst_q} )) {
1017
1008
throw py::value_error (
1018
- " Execution queue context is not the same as allocation context " );
1009
+ " Execution queue is not compatible with the allocation queue " );
1019
1010
}
1020
1011
1012
+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
1021
1013
int dst_typenum = dst.get_typenum ();
1022
1014
int dst_typeid = array_types.typenum_to_lookup_id (dst_typenum);
1023
1015
1024
1016
char *dst_data = dst.get_data ();
1025
1017
sycl::event full_event;
1026
1018
1027
- if (dst_nelems == 1 || (dst_flags & USM_ARRAY_C_CONTIGUOUS) ||
1028
- (dst_flags & USM_ARRAY_F_CONTIGUOUS))
1029
- {
1019
+ if (dst_nelems == 1 || dst.is_c_contiguous () || dst.is_f_contiguous ()) {
1030
1020
auto fn = full_contig_dispatch_vector[dst_typeid];
1031
1021
1032
1022
sycl::event full_contig_event =
@@ -1068,6 +1058,7 @@ eye(py::ssize_t k,
1068
1058
" allocation queue" );
1069
1059
}
1070
1060
1061
+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
1071
1062
int dst_typenum = dst.get_typenum ();
1072
1063
int dst_typeid = array_types.typenum_to_lookup_id (dst_typenum);
1073
1064
@@ -1079,8 +1070,8 @@ eye(py::ssize_t k,
1079
1070
return std::make_pair (sycl::event{}, sycl::event{});
1080
1071
}
1081
1072
1082
- bool is_dst_c_contig = (( dst.get_flags () & USM_ARRAY_C_CONTIGUOUS) != 0 );
1083
- bool is_dst_f_contig = (( dst.get_flags () & USM_ARRAY_F_CONTIGUOUS) != 0 );
1073
+ bool is_dst_c_contig = dst.is_c_contiguous ( );
1074
+ bool is_dst_f_contig = dst.is_f_contiguous ( );
1084
1075
if (!is_dst_c_contig && !is_dst_f_contig) {
1085
1076
throw py::value_error (" USM array is not contiguous" );
1086
1077
}
@@ -1182,6 +1173,8 @@ tri(sycl::queue &exec_q,
1182
1173
throw py::value_error (" Arrays index overlapping segments of memory" );
1183
1174
}
1184
1175
1176
+ auto array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
1177
+
1185
1178
int src_typenum = src.get_typenum ();
1186
1179
int dst_typenum = dst.get_typenum ();
1187
1180
int src_typeid = array_types.typenum_to_lookup_id (src_typenum);
@@ -1203,9 +1196,8 @@ tri(sycl::queue &exec_q,
1203
1196
using shT = std::vector<py::ssize_t >;
1204
1197
shT src_strides (src_nd);
1205
1198
1206
- int src_flags = src.get_flags ();
1207
- bool is_src_c_contig = ((src_flags & USM_ARRAY_C_CONTIGUOUS) != 0 );
1208
- bool is_src_f_contig = ((src_flags & USM_ARRAY_F_CONTIGUOUS) != 0 );
1199
+ bool is_src_c_contig = src.is_c_contiguous ();
1200
+ bool is_src_f_contig = src.is_f_contiguous ();
1209
1201
1210
1202
const py::ssize_t *src_strides_raw = src.get_strides_raw ();
1211
1203
if (src_strides_raw == nullptr ) {
@@ -1227,9 +1219,8 @@ tri(sycl::queue &exec_q,
1227
1219
1228
1220
shT dst_strides (src_nd);
1229
1221
1230
- int dst_flags = dst.get_flags ();
1231
- bool is_dst_c_contig = ((dst_flags & USM_ARRAY_C_CONTIGUOUS) != 0 );
1232
- bool is_dst_f_contig = ((dst_flags & USM_ARRAY_F_CONTIGUOUS) != 0 );
1222
+ bool is_dst_c_contig = dst.is_c_contiguous ();
1223
+ bool is_dst_f_contig = dst.is_f_contiguous ();
1233
1224
1234
1225
const py::ssize_t *dst_strides_raw = dst.get_strides_raw ();
1235
1226
if (dst_strides_raw == nullptr ) {
@@ -1457,9 +1448,6 @@ PYBIND11_MODULE(_tensor_impl, m)
1457
1448
init_copy_for_reshape_dispatch_vector ();
1458
1449
import_dpctl ();
1459
1450
1460
- // populate types constants for type dispatching functions
1461
- array_types = dpctl::tensor::detail::usm_ndarray_types::get ();
1462
-
1463
1451
m.def (
1464
1452
" _contract_iter" , &contract_iter,
1465
1453
" Simplifies iteration of array of given shape & stride. Returns "
0 commit comments