Skip to content

Commit c9e5872

Browse files
Clean up of operator special methods
1. Removed unused usm_ndarray._clone static C-only method 2. Removed _dispatch* utilities 3. Used direct calls to unary/binary operators in implementation of special methods
1 parent 345fdaa commit c9e5872

File tree

2 files changed

+39
-181
lines changed

2 files changed

+39
-181
lines changed

dpctl/tensor/_usmarray.pxd

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ cdef api class usm_ndarray [object PyUSMArrayObject, type PyUSMArrayType]:
5858

5959
cdef void _reset(usm_ndarray self)
6060
cdef void _cleanup(usm_ndarray self)
61-
cdef usm_ndarray _clone(usm_ndarray self)
6261
cdef Py_ssize_t get_offset(usm_ndarray self) except *
6362

6463
cdef char* get_data(self)

dpctl/tensor/_usmarray.pyx

Lines changed: 39 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -44,58 +44,6 @@ include "_types.pxi"
4444
include "_slicing.pxi"
4545

4646

47-
def _dispatch_unary_elementwise(ary, name):
48-
try:
49-
mod = ary.__array_namespace__()
50-
except AttributeError:
51-
return NotImplemented
52-
if mod is None and "dpnp" in sys.modules:
53-
fn = getattr(sys.modules["dpnp"], name)
54-
if callable(fn):
55-
return fn(ary)
56-
elif hasattr(mod, name):
57-
fn = getattr(mod, name)
58-
if callable(fn):
59-
return fn(ary)
60-
61-
return NotImplemented
62-
63-
64-
def _dispatch_binary_elementwise(ary, name, other):
65-
try:
66-
mod = ary.__array_namespace__()
67-
except AttributeError:
68-
return NotImplemented
69-
if mod is None and "dpnp" in sys.modules:
70-
fn = getattr(sys.modules["dpnp"], name)
71-
if callable(fn):
72-
return fn(ary, other)
73-
elif hasattr(mod, name):
74-
fn = getattr(mod, name)
75-
if callable(fn):
76-
return fn(ary, other)
77-
78-
return NotImplemented
79-
80-
81-
def _dispatch_binary_elementwise2(other, name, ary):
82-
try:
83-
mod = ary.__array_namespace__()
84-
except AttributeError:
85-
return NotImplemented
86-
mod = ary.__array_namespace__()
87-
if mod is None and "dpnp" in sys.modules:
88-
fn = getattr(sys.modules["dpnp"], name)
89-
if callable(fn):
90-
return fn(other, ary)
91-
elif hasattr(mod, name):
92-
fn = getattr(mod, name)
93-
if callable(fn):
94-
return fn(other, ary)
95-
96-
return NotImplemented
97-
98-
9947
cdef class InternalUSMArrayError(Exception):
10048
"""
10149
A InternalError exception is raised when internal
@@ -200,28 +148,6 @@ cdef class usm_ndarray:
200148
PyMem_Free(self.strides_)
201149
self._reset()
202150

203-
cdef usm_ndarray _clone(usm_ndarray self):
204-
"""
205-
Provides a copy of Python object pointing to the same data
206-
"""
207-
cdef Py_ssize_t offset_elems = self.get_offset()
208-
cdef usm_ndarray res = usm_ndarray.__new__(
209-
usm_ndarray, _make_int_tuple(self.nd_, self.shape_),
210-
dtype=_make_typestr(self.typenum_),
211-
strides=(
212-
_make_int_tuple(self.nd_, self.strides_) if (self.strides_)
213-
else None),
214-
buffer=self.base_,
215-
offset=offset_elems,
216-
order=('C' if (self.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
217-
)
218-
res.flags_ = self.flags_
219-
res.array_namespace_ = self.array_namespace_
220-
if (res.data_ != self.data_):
221-
raise InternalUSMArrayError(
222-
"Data pointers of cloned and original objects are different.")
223-
return res
224-
225151
def __cinit__(self, shape, dtype=None, strides=None, buffer='device',
226152
Py_ssize_t offset=0, order='C',
227153
buffer_ctor_kwargs=dict(),
@@ -966,32 +892,17 @@ cdef class usm_ndarray:
966892
raise IndexError("only integer arrays are valid indices")
967893

968894
def __abs__(self):
969-
return _dispatch_unary_elementwise(self, "abs")
895+
return dpctl.tensor.abs(self)
970896

971897
def __add__(first, other):
972898
"""
973-
Cython 0.* never calls `__radd__`, always calls `__add__`
974-
but first argument need not be an instance of this class,
975-
so dispatching is needed.
976-
977-
This changes in Cython 3.0, where first is guaranteed to
978-
be `self`.
979-
980-
[1] http://docs.cython.org/en/latest/src/userguide/special_methods.html
899+
Implementation for operator.add
981900
"""
982-
if isinstance(first, usm_ndarray):
983-
return _dispatch_binary_elementwise(first, "add", other)
984-
elif isinstance(other, usm_ndarray):
985-
return _dispatch_binary_elementwise2(first, "add", other)
986-
return NotImplemented
901+
return dpctl.tensor.add(first, other)
987902

988903
def __and__(first, other):
989-
"See comment in __add__"
990-
if isinstance(first, usm_ndarray):
991-
return _dispatch_binary_elementwise(first, "bitwise_and", other)
992-
elif isinstance(other, usm_ndarray):
993-
return _dispatch_binary_elementwise2(first, "bitwise_and", other)
994-
return NotImplemented
904+
"Implementation for operator.and"
905+
return dpctl.tensor.bitwise_and(first, other)
995906

996907
def __dlpack__(self, stream=None):
997908
"""
@@ -1037,27 +948,22 @@ cdef class usm_ndarray:
1037948
)
1038949

1039950
def __eq__(self, other):
1040-
return _dispatch_binary_elementwise(self, "equal", other)
951+
return dpctl.tensor.equal(self, other)
1041952

1042953
def __floordiv__(first, other):
1043-
"See comment in __add__"
1044-
if isinstance(first, usm_ndarray):
1045-
return _dispatch_binary_elementwise(first, "floor_divide", other)
1046-
elif isinstance(other, usm_ndarray):
1047-
return _dispatch_binary_elementwise2(first, "floor_divide", other)
1048-
return NotImplemented
954+
return dpctl.tensor.floor_divide(first, other)
1049955

1050956
def __ge__(self, other):
1051-
return _dispatch_binary_elementwise(self, "greater_equal", other)
957+
return dpctl.tensor.greater_equal(self, other)
1052958

1053959
def __gt__(self, other):
1054-
return _dispatch_binary_elementwise(self, "greater", other)
960+
return dpctl.tensor.greater(self, other)
1055961

1056962
def __invert__(self):
1057-
return _dispatch_unary_elementwise(self, "bitwise_invert")
963+
return dpctl.tensor.bitwise_invert(self)
1058964

1059965
def __le__(self, other):
1060-
return _dispatch_binary_elementwise(self, "less_equal", other)
966+
return dpctl.tensor.less_equal(self, other)
1061967

1062968
def __len__(self):
1063969
if (self.nd_):
@@ -1067,72 +973,40 @@ cdef class usm_ndarray:
1067973

1068974
def __lshift__(first, other):
1069975
"See comment in __add__"
1070-
if isinstance(first, usm_ndarray):
1071-
return _dispatch_binary_elementwise(first, "bitwise_left_shift", other)
1072-
elif isinstance(other, usm_ndarray):
1073-
return _dispatch_binary_elementwise2(first, "bitwise_left_shift", other)
1074-
return NotImplemented
976+
return dpctl.tensor.bitwise_left_shift(first, other)
1075977

1076978
def __lt__(self, other):
1077-
return _dispatch_binary_elementwise(self, "less", other)
979+
return dpctl.tensor.less(self, other)
1078980

1079981
def __matmul__(first, other):
1080-
"See comment in __add__"
1081-
if isinstance(first, usm_ndarray):
1082-
return _dispatch_binary_elementwise(first, "matmul", other)
1083-
elif isinstance(other, usm_ndarray):
1084-
return _dispatch_binary_elementwise2(first, "matmul", other)
1085982
return NotImplemented
1086983

1087984
def __mod__(first, other):
1088-
"See comment in __add__"
1089-
if isinstance(first, usm_ndarray):
1090-
return _dispatch_binary_elementwise(first, "remainder", other)
1091-
elif isinstance(other, usm_ndarray):
1092-
return _dispatch_binary_elementwise2(first, "remainder", other)
1093-
return NotImplemented
985+
return dpctl.tensor.remainder(first, other)
1094986

1095987
def __mul__(first, other):
1096-
"See comment in __add__"
1097-
if isinstance(first, usm_ndarray):
1098-
return _dispatch_binary_elementwise(first, "multiply", other)
1099-
elif isinstance(other, usm_ndarray):
1100-
return _dispatch_binary_elementwise2(first, "multiply", other)
1101-
return NotImplemented
988+
return dpctl.tensor.multiply(first, other)
1102989

1103990
def __ne__(self, other):
1104-
return _dispatch_binary_elementwise(self, "not_equal", other)
991+
return dpctl.tensor.not_equal(self, other)
1105992

1106993
def __neg__(self):
1107-
return _dispatch_unary_elementwise(self, "negative")
994+
return dpctl.tensor.negative(self)
1108995

1109996
def __or__(first, other):
1110-
"See comment in __add__"
1111-
if isinstance(first, usm_ndarray):
1112-
return _dispatch_binary_elementwise(first, "bitwise_or", other)
1113-
elif isinstance(other, usm_ndarray):
1114-
return _dispatch_binary_elementwise2(first, "bitwise_or", other)
1115-
return NotImplemented
997+
return dpctl.tensor.bitwise_or(first, other)
1116998

1117999
def __pos__(self):
1118-
return self # _dispatch_unary_elementwise(self, "positive")
1000+
return dpctl.tensor.positive(self)
11191001

11201002
def __pow__(first, other, mod):
1121-
"See comment in __add__"
11221003
if mod is None:
1123-
if isinstance(first, usm_ndarray):
1124-
return _dispatch_binary_elementwise(first, "pow", other)
1125-
elif isinstance(other, usm_ndarray):
1126-
return _dispatch_binary_elementwise(first, "pow", other)
1127-
return NotImplemented
1004+
return dpctl.tensor.pow(first, other)
1005+
else:
1006+
return NotImplemented
11281007

11291008
def __rshift__(first, other):
1130-
"See comment in __add__"
1131-
if isinstance(first, usm_ndarray):
1132-
return _dispatch_binary_elementwise(first, "bitwise_right_shift", other)
1133-
elif isinstance(other, usm_ndarray):
1134-
return _dispatch_binary_elementwise2(first, "bitwise_right_shift", other)
1135-
return NotImplemented
1009+
return dpctl.tensor.bitwise_right_shift(first, other)
11361010

11371011
def __setitem__(self, key, rhs):
11381012
cdef tuple _meta
@@ -1223,67 +1097,52 @@ cdef class usm_ndarray:
12231097

12241098

12251099
def __sub__(first, other):
1226-
"See comment in __add__"
1227-
if isinstance(first, usm_ndarray):
1228-
return _dispatch_binary_elementwise(first, "subtract", other)
1229-
elif isinstance(other, usm_ndarray):
1230-
return _dispatch_binary_elementwise2(first, "subtract", other)
1231-
return NotImplemented
1100+
return dpctl.tensor.subtract(first, other)
12321101

12331102
def __truediv__(first, other):
1234-
"See comment in __add__"
1235-
if isinstance(first, usm_ndarray):
1236-
return _dispatch_binary_elementwise(first, "divide", other)
1237-
elif isinstance(other, usm_ndarray):
1238-
return _dispatch_binary_elementwise2(first, "divide", other)
1239-
return NotImplemented
1103+
return dpctl.tensor.divide(first, other)
12401104

12411105
def __xor__(first, other):
1242-
"See comment in __add__"
1243-
if isinstance(first, usm_ndarray):
1244-
return _dispatch_binary_elementwise(first, "bitwise_xor", other)
1245-
elif isinstance(other, usm_ndarray):
1246-
return _dispatch_binary_elementwise2(first, "bitwise_xor", other)
1247-
return NotImplemented
1106+
return dpctl.tensor.bitwise_xor(first, other)
12481107

12491108
def __radd__(self, other):
1250-
return _dispatch_binary_elementwise(self, "add", other)
1109+
return dpctl.tensor.add(other, self)
12511110

12521111
def __rand__(self, other):
1253-
return _dispatch_binary_elementwise(self, "bitwise_and", other)
1112+
return dpctl.tensor.bitwise_and(other, self)
12541113

12551114
def __rfloordiv__(self, other):
1256-
return _dispatch_binary_elementwise2(other, "floor_divide", self)
1115+
return dpctl.tensor.floor_divide(other, self)
12571116

12581117
def __rlshift__(self, other):
1259-
return _dispatch_binary_elementwise2(other, "bitwise_left_shift", self)
1118+
return dpctl.tensor.bitwise_left_shift(other, self)
12601119

12611120
def __rmatmul__(self, other):
1262-
return _dispatch_binary_elementwise2(other, "matmul", self)
1121+
return NotImplemented
12631122

12641123
def __rmod__(self, other):
1265-
return _dispatch_binary_elementwise2(other, "remainder", self)
1124+
return dpctl.tensor.remainder(other, self)
12661125

12671126
def __rmul__(self, other):
1268-
return _dispatch_binary_elementwise(self, "multiply", other)
1127+
return dpctl.tensor.multiply(other, self)
12691128

12701129
def __ror__(self, other):
1271-
return _dispatch_binary_elementwise(self, "bitwise_or", other)
1130+
return dpctl.tensor.bitwise_or(other, self)
12721131

12731132
def __rpow__(self, other):
1274-
return _dispatch_binary_elementwise2(other, "pow", self)
1133+
return dpctl.tensor.pow(other, self)
12751134

12761135
def __rrshift__(self, other):
1277-
return _dispatch_binary_elementwise2(other, "bitwise_right_shift", self)
1136+
return dpctl.tensor.bitwise_right_shift(other, self)
12781137

12791138
def __rsub__(self, other):
1280-
return _dispatch_binary_elementwise2(other, "subtract", self)
1139+
return dpctl.tensor.subtract(other, self)
12811140

12821141
def __rtruediv__(self, other):
1283-
return _dispatch_binary_elementwise2(other, "divide", self)
1142+
return dpctl.tensor.divide(other, self)
12841143

12851144
def __rxor__(self, other):
1286-
return _dispatch_binary_elementwise2(other, "bitwise_xor", self)
1145+
return dpctl.tensor.bitwise_xor(other, self)
12871146

12881147
def __iadd__(self, other):
12891148
from ._elementwise_funcs import add

0 commit comments

Comments
 (0)