@@ -44,58 +44,6 @@ include "_types.pxi"
44
44
include " _slicing.pxi"
45
45
46
46
47
- def _dispatch_unary_elementwise (ary , name ):
48
- try :
49
- mod = ary.__array_namespace__()
50
- except AttributeError :
51
- return NotImplemented
52
- if mod is None and " dpnp" in sys.modules:
53
- fn = getattr (sys.modules[" dpnp" ], name)
54
- if callable (fn):
55
- return fn(ary)
56
- elif hasattr (mod, name):
57
- fn = getattr (mod, name)
58
- if callable (fn):
59
- return fn(ary)
60
-
61
- return NotImplemented
62
-
63
-
64
- def _dispatch_binary_elementwise (ary , name , other ):
65
- try :
66
- mod = ary.__array_namespace__()
67
- except AttributeError :
68
- return NotImplemented
69
- if mod is None and " dpnp" in sys.modules:
70
- fn = getattr (sys.modules[" dpnp" ], name)
71
- if callable (fn):
72
- return fn(ary, other)
73
- elif hasattr (mod, name):
74
- fn = getattr (mod, name)
75
- if callable (fn):
76
- return fn(ary, other)
77
-
78
- return NotImplemented
79
-
80
-
81
- def _dispatch_binary_elementwise2 (other , name , ary ):
82
- try :
83
- mod = ary.__array_namespace__()
84
- except AttributeError :
85
- return NotImplemented
86
- mod = ary.__array_namespace__()
87
- if mod is None and " dpnp" in sys.modules:
88
- fn = getattr (sys.modules[" dpnp" ], name)
89
- if callable (fn):
90
- return fn(other, ary)
91
- elif hasattr (mod, name):
92
- fn = getattr (mod, name)
93
- if callable (fn):
94
- return fn(other, ary)
95
-
96
- return NotImplemented
97
-
98
-
99
47
cdef class InternalUSMArrayError(Exception ):
100
48
"""
101
49
A InternalError exception is raised when internal
@@ -200,28 +148,6 @@ cdef class usm_ndarray:
200
148
PyMem_Free(self .strides_)
201
149
self ._reset()
202
150
203
- cdef usm_ndarray _clone(usm_ndarray self ):
204
- """
205
- Provides a copy of Python object pointing to the same data
206
- """
207
- cdef Py_ssize_t offset_elems = self .get_offset()
208
- cdef usm_ndarray res = usm_ndarray.__new__ (
209
- usm_ndarray, _make_int_tuple(self .nd_, self .shape_),
210
- dtype = _make_typestr(self .typenum_),
211
- strides = (
212
- _make_int_tuple(self .nd_, self .strides_) if (self .strides_)
213
- else None ),
214
- buffer = self .base_,
215
- offset = offset_elems,
216
- order = (' C' if (self .flags_ & USM_ARRAY_C_CONTIGUOUS) else ' F' )
217
- )
218
- res.flags_ = self .flags_
219
- res.array_namespace_ = self .array_namespace_
220
- if (res.data_ != self .data_):
221
- raise InternalUSMArrayError(
222
- " Data pointers of cloned and original objects are different." )
223
- return res
224
-
225
151
def __cinit__ (self , shape , dtype = None , strides = None , buffer = ' device' ,
226
152
Py_ssize_t offset = 0 , order = ' C' ,
227
153
buffer_ctor_kwargs = dict (),
@@ -966,32 +892,17 @@ cdef class usm_ndarray:
966
892
raise IndexError (" only integer arrays are valid indices" )
967
893
968
894
def __abs__ (self ):
969
- return _dispatch_unary_elementwise (self , " abs " )
895
+ return dpctl.tensor.abs (self )
970
896
971
897
def __add__ (first , other ):
972
898
"""
973
- Cython 0.* never calls `__radd__`, always calls `__add__`
974
- but first argument need not be an instance of this class,
975
- so dispatching is needed.
976
-
977
- This changes in Cython 3.0, where first is guaranteed to
978
- be `self`.
979
-
980
- [1] http://docs.cython.org/en/latest/src/userguide/special_methods.html
899
+ Implementation for operator.add
981
900
"""
982
- if isinstance (first, usm_ndarray):
983
- return _dispatch_binary_elementwise(first, " add" , other)
984
- elif isinstance (other, usm_ndarray):
985
- return _dispatch_binary_elementwise2(first, " add" , other)
986
- return NotImplemented
901
+ return dpctl.tensor.add(first, other)
987
902
988
903
def __and__ (first , other ):
989
- " See comment in __add__"
990
- if isinstance (first, usm_ndarray):
991
- return _dispatch_binary_elementwise(first, " bitwise_and" , other)
992
- elif isinstance (other, usm_ndarray):
993
- return _dispatch_binary_elementwise2(first, " bitwise_and" , other)
994
- return NotImplemented
904
+ " Implementation for operator.and"
905
+ return dpctl.tensor.bitwise_and(first, other)
995
906
996
907
def __dlpack__ (self , stream = None ):
997
908
"""
@@ -1037,27 +948,22 @@ cdef class usm_ndarray:
1037
948
)
1038
949
1039
950
def __eq__ (self , other ):
1040
- return _dispatch_binary_elementwise (self , " equal " , other)
951
+ return dpctl.tensor.equal (self , other)
1041
952
1042
953
def __floordiv__ (first , other ):
1043
- " See comment in __add__"
1044
- if isinstance (first, usm_ndarray):
1045
- return _dispatch_binary_elementwise(first, " floor_divide" , other)
1046
- elif isinstance (other, usm_ndarray):
1047
- return _dispatch_binary_elementwise2(first, " floor_divide" , other)
1048
- return NotImplemented
954
+ return dpctl.tensor.floor_divide(first, other)
1049
955
1050
956
def __ge__ (self , other ):
1051
- return _dispatch_binary_elementwise (self , " greater_equal " , other)
957
+ return dpctl.tensor.greater_equal (self , other)
1052
958
1053
959
def __gt__ (self , other ):
1054
- return _dispatch_binary_elementwise (self , " greater " , other)
960
+ return dpctl.tensor.greater (self , other)
1055
961
1056
962
def __invert__ (self ):
1057
- return _dispatch_unary_elementwise (self , " bitwise_invert " )
963
+ return dpctl.tensor.bitwise_invert (self )
1058
964
1059
965
def __le__ (self , other ):
1060
- return _dispatch_binary_elementwise (self , " less_equal " , other)
966
+ return dpctl.tensor.less_equal (self , other)
1061
967
1062
968
def __len__ (self ):
1063
969
if (self .nd_):
@@ -1067,72 +973,40 @@ cdef class usm_ndarray:
1067
973
1068
974
def __lshift__ (first , other ):
1069
975
" See comment in __add__"
1070
- if isinstance (first, usm_ndarray):
1071
- return _dispatch_binary_elementwise(first, " bitwise_left_shift" , other)
1072
- elif isinstance (other, usm_ndarray):
1073
- return _dispatch_binary_elementwise2(first, " bitwise_left_shift" , other)
1074
- return NotImplemented
976
+ return dpctl.tensor.bitwise_left_shift(first, other)
1075
977
1076
978
def __lt__ (self , other ):
1077
- return _dispatch_binary_elementwise (self , " less " , other)
979
+ return dpctl.tensor.less (self , other)
1078
980
1079
981
def __matmul__ (first , other ):
1080
- " See comment in __add__"
1081
- if isinstance (first, usm_ndarray):
1082
- return _dispatch_binary_elementwise(first, " matmul" , other)
1083
- elif isinstance (other, usm_ndarray):
1084
- return _dispatch_binary_elementwise2(first, " matmul" , other)
1085
982
return NotImplemented
1086
983
1087
984
def __mod__ (first , other ):
1088
- " See comment in __add__"
1089
- if isinstance (first, usm_ndarray):
1090
- return _dispatch_binary_elementwise(first, " remainder" , other)
1091
- elif isinstance (other, usm_ndarray):
1092
- return _dispatch_binary_elementwise2(first, " remainder" , other)
1093
- return NotImplemented
985
+ return dpctl.tensor.remainder(first, other)
1094
986
1095
987
def __mul__ (first , other ):
1096
- " See comment in __add__"
1097
- if isinstance (first, usm_ndarray):
1098
- return _dispatch_binary_elementwise(first, " multiply" , other)
1099
- elif isinstance (other, usm_ndarray):
1100
- return _dispatch_binary_elementwise2(first, " multiply" , other)
1101
- return NotImplemented
988
+ return dpctl.tensor.multiply(first, other)
1102
989
1103
990
def __ne__ (self , other ):
1104
- return _dispatch_binary_elementwise (self , " not_equal " , other)
991
+ return dpctl.tensor.not_equal (self , other)
1105
992
1106
993
def __neg__ (self ):
1107
- return _dispatch_unary_elementwise (self , " negative " )
994
+ return dpctl.tensor.negative (self )
1108
995
1109
996
def __or__ (first , other ):
1110
- " See comment in __add__"
1111
- if isinstance (first, usm_ndarray):
1112
- return _dispatch_binary_elementwise(first, " bitwise_or" , other)
1113
- elif isinstance (other, usm_ndarray):
1114
- return _dispatch_binary_elementwise2(first, " bitwise_or" , other)
1115
- return NotImplemented
997
+ return dpctl.tensor.bitwise_or(first, other)
1116
998
1117
999
def __pos__ (self ):
1118
- return self # _dispatch_unary_elementwise (self, "positive" )
1000
+ return dpctl.tensor.positive (self )
1119
1001
1120
1002
def __pow__ (first , other , mod ):
1121
- " See comment in __add__"
1122
1003
if mod is None :
1123
- if isinstance (first, usm_ndarray):
1124
- return _dispatch_binary_elementwise(first, " pow" , other)
1125
- elif isinstance (other, usm_ndarray):
1126
- return _dispatch_binary_elementwise(first, " pow" , other)
1127
- return NotImplemented
1004
+ return dpctl.tensor.pow(first, other)
1005
+ else :
1006
+ return NotImplemented
1128
1007
1129
1008
def __rshift__ (first , other ):
1130
- " See comment in __add__"
1131
- if isinstance (first, usm_ndarray):
1132
- return _dispatch_binary_elementwise(first, " bitwise_right_shift" , other)
1133
- elif isinstance (other, usm_ndarray):
1134
- return _dispatch_binary_elementwise2(first, " bitwise_right_shift" , other)
1135
- return NotImplemented
1009
+ return dpctl.tensor.bitwise_right_shift(first, other)
1136
1010
1137
1011
def __setitem__ (self , key , rhs ):
1138
1012
cdef tuple _meta
@@ -1223,67 +1097,52 @@ cdef class usm_ndarray:
1223
1097
1224
1098
1225
1099
def __sub__ (first , other ):
1226
- " See comment in __add__"
1227
- if isinstance (first, usm_ndarray):
1228
- return _dispatch_binary_elementwise(first, " subtract" , other)
1229
- elif isinstance (other, usm_ndarray):
1230
- return _dispatch_binary_elementwise2(first, " subtract" , other)
1231
- return NotImplemented
1100
+ return dpctl.tensor.subtract(first, other)
1232
1101
1233
1102
def __truediv__ (first , other ):
1234
- " See comment in __add__"
1235
- if isinstance (first, usm_ndarray):
1236
- return _dispatch_binary_elementwise(first, " divide" , other)
1237
- elif isinstance (other, usm_ndarray):
1238
- return _dispatch_binary_elementwise2(first, " divide" , other)
1239
- return NotImplemented
1103
+ return dpctl.tensor.divide(first, other)
1240
1104
1241
1105
def __xor__ (first , other ):
1242
- " See comment in __add__"
1243
- if isinstance (first, usm_ndarray):
1244
- return _dispatch_binary_elementwise(first, " bitwise_xor" , other)
1245
- elif isinstance (other, usm_ndarray):
1246
- return _dispatch_binary_elementwise2(first, " bitwise_xor" , other)
1247
- return NotImplemented
1106
+ return dpctl.tensor.bitwise_xor(first, other)
1248
1107
1249
1108
def __radd__ (self , other ):
1250
- return _dispatch_binary_elementwise( self , " add " , other )
1109
+ return dpctl.tensor.add(other, self )
1251
1110
1252
1111
def __rand__ (self , other ):
1253
- return _dispatch_binary_elementwise( self , " bitwise_and " , other )
1112
+ return dpctl.tensor.bitwise_and(other, self )
1254
1113
1255
1114
def __rfloordiv__ (self , other ):
1256
- return _dispatch_binary_elementwise2 (other, " floor_divide " , self )
1115
+ return dpctl.tensor.floor_divide (other, self )
1257
1116
1258
1117
def __rlshift__ (self , other ):
1259
- return _dispatch_binary_elementwise2 (other, " bitwise_left_shift " , self )
1118
+ return dpctl.tensor.bitwise_left_shift (other, self )
1260
1119
1261
1120
def __rmatmul__ (self , other ):
1262
- return _dispatch_binary_elementwise2(other, " matmul " , self )
1121
+ return NotImplemented
1263
1122
1264
1123
def __rmod__ (self , other ):
1265
- return _dispatch_binary_elementwise2 (other, " remainder " , self )
1124
+ return dpctl.tensor.remainder (other, self )
1266
1125
1267
1126
def __rmul__ (self , other ):
1268
- return _dispatch_binary_elementwise( self , " multiply " , other )
1127
+ return dpctl.tensor.multiply(other, self )
1269
1128
1270
1129
def __ror__ (self , other ):
1271
- return _dispatch_binary_elementwise( self , " bitwise_or " , other )
1130
+ return dpctl.tensor.bitwise_or(other, self )
1272
1131
1273
1132
def __rpow__ (self , other ):
1274
- return _dispatch_binary_elementwise2 (other, " pow " , self )
1133
+ return dpctl.tensor.pow (other, self )
1275
1134
1276
1135
def __rrshift__ (self , other ):
1277
- return _dispatch_binary_elementwise2 (other, " bitwise_right_shift " , self )
1136
+ return dpctl.tensor.bitwise_right_shift (other, self )
1278
1137
1279
1138
def __rsub__ (self , other ):
1280
- return _dispatch_binary_elementwise2 (other, " subtract " , self )
1139
+ return dpctl.tensor.subtract (other, self )
1281
1140
1282
1141
def __rtruediv__ (self , other ):
1283
- return _dispatch_binary_elementwise2 (other, " divide " , self )
1142
+ return dpctl.tensor.divide (other, self )
1284
1143
1285
1144
def __rxor__ (self , other ):
1286
- return _dispatch_binary_elementwise2 (other, " bitwise_xor " , self )
1145
+ return dpctl.tensor.bitwise_xor (other, self )
1287
1146
1288
1147
def __iadd__ (self , other ):
1289
1148
from ._elementwise_funcs import add
0 commit comments