@@ -930,126 +930,134 @@ def __call__(self, o1, o2, /, *, out=None, order="K"):
930
930
return out
931
931
932
932
def _inplace_op(self, o1, o2):
    """Apply the binary function in place, writing the result into ``o1``.

    Args:
        o1: Left-hand operand. Must be a writable
            ``dpctl.tensor.usm_ndarray``; it receives the result.
        o2: Right-hand operand. Either a ``dpctl.tensor.usm_ndarray`` or
            another object that is converted with ``dpt.asarray`` on the
            execution queue.

    Returns:
        ``o1``, after the in-place kernel has been submitted and its
        events registered with the order manager of the execution queue.

    Raises:
        ValueError: If the function has no dedicated in-place
            implementation, ``o1`` is read-only, the operand shapes can
            not be broadcast, broadcasting would change the shape of
            ``o1``, ``o2`` has an unsupported data type, or the resolved
            result type is unavailable or differs from the type of ``o1``.
        TypeError: If ``o1`` is not a ``dpctl.tensor.usm_ndarray`` or the
            shape of ``o2`` is not a list or tuple.
        ExecutionPlacementError: If no execution queue can be inferred
            from the two operands.
    """
    # Guard first: nothing below makes sense without an in-place kernel.
    if self.binary_inplace_fn_ is None:
        raise ValueError(
            "binary function does not have a dedicated in-place "
            "implementation"
        )

    if not isinstance(o1, dpt.usm_ndarray):
        raise TypeError(
            "Expected first argument to be "
            f"dpctl.tensor.usm_ndarray, got {type(o1)}"
        )
    if not o1.flags.writable:
        raise ValueError("provided left-hand side array is read-only")

    # Resolve the execution queue and the coerced USM type. When the
    # helper reports no queue for o2, o1 alone decides both.
    q1, o1_usm_type = o1.sycl_queue, o1.usm_type
    q2, o2_usm_type = _get_queue_usm_type(o2)
    if q2 is None:
        exec_q = q1
        res_usm_type = o1_usm_type
    else:
        exec_q = dpctl.utils.get_execution_queue((q1, q2))
        if exec_q is None:
            raise ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        res_usm_type = dpctl.utils.get_coerced_usm_type(
            (
                o1_usm_type,
                o2_usm_type,
            )
        )
    dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)

    o1_shape = o1.shape
    o2_shape = _get_shape(o2)
    if not isinstance(o2_shape, (tuple, list)):
        raise TypeError(
            "Shape of second argument can not be inferred. "
            "Expected list or tuple."
        )
    try:
        res_shape = _broadcast_shape_impl(
            [
                o1_shape,
                o2_shape,
            ]
        )
    except ValueError as exc:
        raise ValueError(
            "operands could not be broadcast together with shapes "
            f"{o1_shape} and {o2_shape}"
        ) from exc
    # The left-hand side is the destination, so it can not be broadcast:
    # the broadcast shape has to be exactly the shape of o1.
    if res_shape != o1_shape:
        raise ValueError(
            "The shape of the non-broadcastable left-hand "
            f"side {o1_shape} is inconsistent with the "
            f"broadcast shape {res_shape}."
        )

    sycl_dev = exec_q.sycl_device
    o1_dtype = o1.dtype
    o2_dtype = _get_dtype(o2, sycl_dev)
    if not _validate_dtype(o2_dtype):
        raise ValueError("Operand has an unsupported data type")

    o1_dtype, o2_dtype = self.weak_type_resolver_(
        o1_dtype, o2_dtype, sycl_dev
    )

    # buf_dt is the type o2 must be copied into before the kernel runs
    # (None when o2 can be used as is); res_dt is the result type.
    buf_dt, res_dt = _find_buf_dtype_in_place_op(
        o1_dtype,
        o2_dtype,
        self.result_type_resolver_fn_,
        sycl_dev,
    )

    if res_dt is None:
        raise ValueError(
            f"function '{self.name_}' does not support input types "
            f"({o1_dtype}, {o2_dtype}), "
            "and the inputs could not be safely coerced to any "
            "supported types according to the casting rule "
            "''same_kind''."
        )

    if res_dt != o1_dtype:
        raise ValueError(
            f"Output array of type {res_dt} is needed, "
            f"got {o1_dtype}"
        )

    _manager = SequentialOrderManager[exec_q]
    if isinstance(o2, dpt.usm_ndarray):
        src2 = o2
        # o2 overlaps o1 without being the same logical tensor: force a
        # temporary copy of o2 so the kernel does not read from memory
        # it is also writing to.
        if (
            ti._array_overlap(o2, o1)
            and not ti._same_logical_tensors(o2, o1)
            and buf_dt is None
        ):
            buf_dt = o2_dtype
    else:
        src2 = dpt.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q)

    if buf_dt is None:
        # Direct path: o2 is used as the right-hand side without a copy.
        if src2.shape != res_shape:
            src2 = dpt.broadcast_to(src2, res_shape)
        dep_evs = _manager.submitted_events
        ht_, comp_ev = self.binary_inplace_fn_(
            lhs=o1,
            rhs=src2,
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_, comp_ev)
    else:
        # Buffered path: copy o2 into a temporary of type buf_dt, then
        # run the kernel once that copy has completed.
        buf = dpt.empty_like(src2, dtype=buf_dt)
        dep_evs = _manager.submitted_events
        (
            ht_copy_ev,
            copy_ev,
        ) = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src2,
            dst=buf,
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_copy_ev, copy_ev)

        buf = dpt.broadcast_to(buf, res_shape)
        ht_, bf_ev = self.binary_inplace_fn_(
            lhs=o1,
            rhs=buf,
            sycl_queue=exec_q,
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_, bf_ev)

    return o1
0 commit comments