@@ -766,22 +766,26 @@ def _take_multi_index(ary, inds, p):
766
766
raise TypeError (
767
767
f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
768
768
)
769
+ ary_nd = ary .ndim
770
+ p = normalize_axis_index (operator .index (p ), ary_nd )
769
771
queues_ = [
770
772
ary .sycl_queue ,
771
773
]
772
774
usm_types_ = [
773
775
ary .usm_type ,
774
776
]
775
- if not isinstance (inds , list ) and not isinstance ( inds , tuple ):
777
+ if not isinstance (inds , ( list , tuple ) ):
776
778
inds = (inds ,)
777
- all_integers = True
778
779
for ind in inds :
779
780
if not isinstance (ind , dpt .usm_ndarray ):
780
781
raise TypeError ("all elements of `ind` expected to be usm_ndarrays" )
781
782
queues_ .append (ind .sycl_queue )
782
783
usm_types_ .append (ind .usm_type )
783
- if all_integers :
784
- all_integers = ind .dtype .kind in "ui"
784
+ if ind .dtype .kind not in "ui" :
785
+ raise IndexError (
786
+ "arrays used as indices must be of integer (or boolean) type"
787
+ )
788
+ res_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
785
789
exec_q = dpctl .utils .get_execution_queue (queues_ )
786
790
if exec_q is None :
787
791
raise dpctl .utils .ExecutionPlacementError (
@@ -790,30 +794,36 @@ def _take_multi_index(ary, inds, p):
790
794
"Use `usm_ndarray.to_device` method to migrate data to "
791
795
"be associated with the same queue."
792
796
)
793
- if not all_integers :
794
- raise IndexError (
795
- "arrays used as indices must be of integer (or boolean) type"
796
- )
797
797
if len (inds ) > 1 :
798
+ ind_dt = dpt .result_type (* inds )
799
+ # ind arrays have been checked to be of integer dtype
800
+ if ind_dt .kind not in "ui" :
801
+ raise ValueError (
802
+ "cannot safely promote indices to an integer data type"
803
+ )
804
+ inds = tuple (
805
+ map (
806
+ lambda ind : ind
807
+ if ind .dtype == ind_dt
808
+ else dpt .astype (ind , ind_dt ),
809
+ inds ,
810
+ )
811
+ )
798
812
inds = dpt .broadcast_arrays (* inds )
813
+ ind0 = inds [0 ]
799
814
ary_sh = ary .shape
800
- ary_nd = ary .ndim
801
- p = normalize_axis_index (operator .index (p ), ary_nd )
802
815
p_end = p + len (inds )
803
- inds_sz = inds [ 0 ] .size
816
+ inds_sz = ind0 .size
804
817
if 0 in ary_sh [p : p_end + 1 ] and inds_sz != 0 :
805
818
raise IndexError ("cannot take non-empty indices from an empty axis" )
806
- res_shape = ary_sh [:p ] + inds [0 ].shape + ary_sh [p + len (inds ) :]
807
- res_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
819
+ res_shape = ary_sh [:p ] + ind0 .shape + ary_sh [p_end :]
808
820
res = dpt .empty (
809
821
res_shape , dtype = ary .dtype , usm_type = res_usm_type , sycl_queue = exec_q
810
822
)
811
-
812
823
hev , _ = ti ._take (
813
824
src = ary , ind = inds , dst = res , axis_start = p , mode = 0 , sycl_queue = exec_q
814
825
)
815
826
hev .wait ()
816
-
817
827
return res
818
828
819
829
@@ -881,6 +891,8 @@ def _put_multi_index(ary, inds, p, vals):
881
891
raise TypeError (
882
892
f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
883
893
)
894
+ ary_nd = ary .ndim
895
+ p = normalize_axis_index (operator .index (p ), ary_nd )
884
896
if isinstance (vals , dpt .usm_ndarray ):
885
897
queues_ = [ary .sycl_queue , vals .sycl_queue ]
886
898
usm_types_ = [ary .usm_type , vals .usm_type ]
@@ -891,22 +903,19 @@ def _put_multi_index(ary, inds, p, vals):
891
903
usm_types_ = [
892
904
ary .usm_type ,
893
905
]
894
- if not isinstance (inds , list ) and not isinstance ( inds , tuple ):
906
+ if not isinstance (inds , ( list , tuple ) ):
895
907
inds = (inds ,)
896
- all_integers = True
897
908
for ind in inds :
898
909
if not isinstance (ind , dpt .usm_ndarray ):
899
910
raise TypeError ("all elements of `ind` expected to be usm_ndarrays" )
900
911
queues_ .append (ind .sycl_queue )
901
912
usm_types_ .append (ind .usm_type )
902
- if all_integers :
903
- all_integers = ind .dtype .kind in "ui"
904
- if not all_integers :
905
- raise IndexError (
906
- "arrays used as indices must be of integer (or boolean) type"
907
- )
908
- exec_q = dpctl .utils .get_execution_queue (queues_ )
913
+ if ind .dtype .kind not in "ui" :
914
+ raise IndexError (
915
+ "arrays used as indices must be of integer (or boolean) type"
916
+ )
909
917
vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
918
+ exec_q = dpctl .utils .get_execution_queue (queues_ )
910
919
if exec_q is not None :
911
920
if not isinstance (vals , dpt .usm_ndarray ):
912
921
vals = dpt .asarray (
@@ -922,19 +931,28 @@ def _put_multi_index(ary, inds, p, vals):
922
931
"be associated with the same queue."
923
932
)
924
933
if len (inds ) > 1 :
934
+ ind_dt = dpt .result_type (* inds )
935
+ # ind arrays have been checked to be of integer dtype
936
+ if ind_dt .kind not in "ui" :
937
+ raise ValueError (
938
+ "cannot safely promote indices to an integer data type"
939
+ )
940
+ inds = tuple (
941
+ map (
942
+ lambda ind : ind
943
+ if ind .dtype == ind_dt
944
+ else dpt .astype (ind , ind_dt ),
945
+ inds ,
946
+ )
947
+ )
925
948
inds = dpt .broadcast_arrays (* inds )
949
+ ind0 = inds [0 ]
926
950
ary_sh = ary .shape
927
- ary_nd = ary .ndim
928
- p = normalize_axis_index (operator .index (p ), ary_nd )
929
951
p_end = p + len (inds )
930
- inds_sz = inds [ 0 ] .size
952
+ inds_sz = ind0 .size
931
953
if 0 in ary_sh [p : p_end + 1 ] and inds_sz != 0 :
932
- raise IndexError (
933
- "cannot put elements at non-empty indices along empty axis"
934
- )
935
- expected_vals_shape = (
936
- ary .shape [:p ] + inds [0 ].shape + ary .shape [p + len (inds ) :]
937
- )
954
+ raise IndexError ("cannot put into non-empty indices from an empty axis" )
955
+ expected_vals_shape = ary_sh [:p ] + ind0 .shape + ary_sh [p_end :]
938
956
if vals .dtype == ary .dtype :
939
957
rhs = vals
940
958
else :
@@ -944,5 +962,4 @@ def _put_multi_index(ary, inds, p, vals):
944
962
dst = ary , ind = inds , val = rhs , axis_start = p , mode = 0 , sycl_queue = exec_q
945
963
)
946
964
hev .wait ()
947
-
948
965
return
0 commit comments