@@ -763,7 +763,9 @@ def _nonzero_impl(ary):
763
763
764
764
def _take_multi_index (ary , inds , p ):
765
765
if not isinstance (ary , dpt .usm_ndarray ):
766
- raise TypeError
766
+ raise TypeError (
767
+ f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
768
+ )
767
769
queues_ = [
768
770
ary .sycl_queue ,
769
771
]
@@ -774,23 +776,34 @@ def _take_multi_index(ary, inds, p):
774
776
inds = (inds ,)
775
777
all_integers = True
776
778
for ind in inds :
779
+ if not isinstance (ind , dpt .usm_ndarray ):
780
+ raise TypeError ("all elements of `ind` expected to be usm_ndarrays" )
777
781
queues_ .append (ind .sycl_queue )
778
782
usm_types_ .append (ind .usm_type )
779
783
if all_integers :
780
784
all_integers = ind .dtype .kind in "ui"
781
785
exec_q = dpctl .utils .get_execution_queue (queues_ )
782
786
if exec_q is None :
783
- raise dpctl .utils .ExecutionPlacementError ("" )
787
+ raise dpctl .utils .ExecutionPlacementError (
788
+ "Can not automatically determine where to allocate the "
789
+ "result or performance execution. "
790
+ "Use `usm_ndarray.to_device` method to migrate data to "
791
+ "be associated with the same queue."
792
+ )
784
793
if not all_integers :
785
794
raise IndexError (
786
795
"arrays used as indices must be of integer (or boolean) type"
787
796
)
788
797
if len (inds ) > 1 :
789
798
inds = dpt .broadcast_arrays (* inds )
790
- ary_ndim = ary .ndim
791
- p = normalize_axis_index (operator .index (p ), ary_ndim )
792
-
793
- res_shape = ary .shape [:p ] + inds [0 ].shape + ary .shape [p + len (inds ) :]
799
+ ary_sh = ary .shape
800
+ ary_nd = ary .ndim
801
+ p = normalize_axis_index (operator .index (p ), ary_nd )
802
+ p_end = p + len (inds )
803
+ inds_sz = inds [0 ].size
804
+ if 0 in ary_sh [p : p_end + 1 ] and inds_sz != 0 :
805
+ raise IndexError ("cannot take non-empty indices from an empty axis" )
806
+ res_shape = ary_sh [:p ] + inds [0 ].shape + ary_sh [p + len (inds ) :]
794
807
res_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
795
808
res = dpt .empty (
796
809
res_shape , dtype = ary .dtype , usm_type = res_usm_type , sycl_queue = exec_q
@@ -864,6 +877,10 @@ def _place_impl(ary, ary_mask, vals, axis=0):
864
877
865
878
866
879
def _put_multi_index (ary , inds , p , vals ):
880
+ if not isinstance (ary , dpt .usm_ndarray ):
881
+ raise TypeError (
882
+ f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
883
+ )
867
884
if isinstance (vals , dpt .usm_ndarray ):
868
885
queues_ = [ary .sycl_queue , vals .sycl_queue ]
869
886
usm_types_ = [ary .usm_type , vals .usm_type ]
@@ -879,40 +896,52 @@ def _put_multi_index(ary, inds, p, vals):
879
896
all_integers = True
880
897
for ind in inds :
881
898
if not isinstance (ind , dpt .usm_ndarray ):
882
- raise TypeError
899
+ raise TypeError ( "all elements of `ind` expected to be usm_ndarrays" )
883
900
queues_ .append (ind .sycl_queue )
884
901
usm_types_ .append (ind .usm_type )
885
902
if all_integers :
886
903
all_integers = ind .dtype .kind in "ui"
904
+ if not all_integers :
905
+ raise IndexError (
906
+ "arrays used as indices must be of integer (or boolean) type"
907
+ )
887
908
exec_q = dpctl .utils .get_execution_queue (queues_ )
909
+ vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
910
+ if exec_q is not None :
911
+ if not isinstance (vals , dpt .usm_ndarray ):
912
+ vals = dpt .asarray (
913
+ vals , dtype = ary .dtype , usm_type = vals_usm_type , sycl_queue = exec_q
914
+ )
915
+ else :
916
+ exec_q = dpctl .utils .get_execution_queue ((exec_q , vals .sycl_queue ))
888
917
if exec_q is None :
889
918
raise dpctl .utils .ExecutionPlacementError (
890
919
"Can not automatically determine where to allocate the "
891
920
"result or performance execution. "
892
921
"Use `usm_ndarray.to_device` method to migrate data to "
893
922
"be associated with the same queue."
894
923
)
895
- if not all_integers :
896
- raise IndexError (
897
- "arrays used as indices must be of integer (or boolean) type"
898
- )
899
924
if len (inds ) > 1 :
900
925
inds = dpt .broadcast_arrays (* inds )
901
- ary_ndim = ary .ndim
902
-
903
- p = normalize_axis_index (operator .index (p ), ary_ndim )
904
- vals_shape = ary .shape [:p ] + inds [0 ].shape + ary .shape [p + len (inds ) :]
905
-
906
- vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
907
- if not isinstance (vals , dpt .usm_ndarray ):
908
- vals = dpt .asarray (
909
- vals , dtype = ary .dtype , usm_type = vals_usm_type , sycl_queue = exec_q
926
+ ary_sh = ary .shape
927
+ ary_nd = ary .ndim
928
+ p = normalize_axis_index (operator .index (p ), ary_nd )
929
+ p_end = p + len (inds )
930
+ inds_sz = inds [0 ].size
931
+ if 0 in ary_sh [p : p_end + 1 ] and inds_sz != 0 :
932
+ raise IndexError (
933
+ "cannot put elements at non-empty indices along empty axis"
910
934
)
911
-
912
- vals = dpt .broadcast_to (vals , vals_shape )
913
-
935
+ expected_vals_shape = (
936
+ ary .shape [:p ] + inds [0 ].shape + ary .shape [p + len (inds ) :]
937
+ )
938
+ if vals .dtype == ary .dtype :
939
+ rhs = vals
940
+ else :
941
+ rhs = dpt .astype (vals , ary .dtype )
942
+ rhs = dpt .broadcast_to (rhs , expected_vals_shape )
914
943
hev , _ = ti ._put (
915
- dst = ary , ind = inds , val = vals , axis_start = p , mode = 0 , sycl_queue = exec_q
944
+ dst = ary , ind = inds , val = rhs , axis_start = p , mode = 0 , sycl_queue = exec_q
916
945
)
917
946
hev .wait ()
918
947
0 commit comments