Skip to content

Commit acfab08

Browse files
authored
Merge pull request #103 from ev-br/binops_reject_ndarrays
BUG: add missing `_check_type_device` calls
2 parents 884f3b8 + 8bc3de3 commit acfab08

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

array_api_strict/_array_object.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,7 @@ def __imod__(self, other: Array | float, /) -> Array:
11041104
"""
11051105
Performs the operation __imod__.
11061106
"""
1107+
self._check_type_device(other)
11071108
other = self._check_allowed_dtypes(other, "real numeric", "__imod__")
11081109
if other is NotImplemented:
11091110
return other
@@ -1126,6 +1127,7 @@ def __imul__(self, other: Array | complex, /) -> Array:
11261127
"""
11271128
Performs the operation __imul__.
11281129
"""
1130+
self._check_type_device(other)
11291131
other = self._check_allowed_dtypes(other, "numeric", "__imul__")
11301132
if other is NotImplemented:
11311133
return other
@@ -1148,6 +1150,7 @@ def __ior__(self, other: Array | int, /) -> Array:
11481150
"""
11491151
Performs the operation __ior__.
11501152
"""
1153+
self._check_type_device(other)
11511154
other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__")
11521155
if other is NotImplemented:
11531156
return other
@@ -1170,6 +1173,7 @@ def __ipow__(self, other: Array | complex, /) -> Array:
11701173
"""
11711174
Performs the operation __ipow__.
11721175
"""
1176+
self._check_type_device(other)
11731177
other = self._check_allowed_dtypes(other, "numeric", "__ipow__")
11741178
if other is NotImplemented:
11751179
return other
@@ -1182,6 +1186,7 @@ def __rpow__(self, other: Array | complex, /) -> Array:
11821186
"""
11831187
from ._elementwise_functions import pow # type: ignore[attr-defined]
11841188

1189+
self._check_type_device(other)
11851190
other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
11861191
if other is NotImplemented:
11871192
return other
@@ -1193,6 +1198,7 @@ def __irshift__(self, other: Array | int, /) -> Array:
11931198
"""
11941199
Performs the operation __irshift__.
11951200
"""
1201+
self._check_type_device(other)
11961202
other = self._check_allowed_dtypes(other, "integer", "__irshift__")
11971203
if other is NotImplemented:
11981204
return other
@@ -1215,6 +1221,7 @@ def __isub__(self, other: Array | complex, /) -> Array:
12151221
"""
12161222
Performs the operation __isub__.
12171223
"""
1224+
self._check_type_device(other)
12181225
other = self._check_allowed_dtypes(other, "numeric", "__isub__")
12191226
if other is NotImplemented:
12201227
return other
@@ -1237,6 +1244,7 @@ def __itruediv__(self, other: Array | complex, /) -> Array:
12371244
"""
12381245
Performs the operation __itruediv__.
12391246
"""
1247+
self._check_type_device(other)
12401248
other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__")
12411249
if other is NotImplemented:
12421250
return other
@@ -1259,6 +1267,7 @@ def __ixor__(self, other: Array | int, /) -> Array:
12591267
"""
12601268
Performs the operation __ixor__.
12611269
"""
1270+
self._check_type_device(other)
12621271
other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__")
12631272
if other is NotImplemented:
12641273
return other

array_api_strict/tests/test_array_object.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,13 @@ def _array_vals():
344344
getattr(x, _op)(y)
345345
else:
346346
assert_raises(TypeError, lambda: getattr(x, _op)(y))
347+
# finally, test that array op ndarray raises
348+
# XXX: as long as there is __array__ or __buffer__, __rop__s
349+
# still return ndarrays
350+
if not _op.startswith("__r"):
351+
with assert_raises(TypeError):
352+
getattr(x, _op)(y._array)
353+
347354

348355
for op, dtypes in unary_op_dtypes.items():
349356
for a in _array_vals():

0 commit comments

Comments
 (0)