Skip to content

Commit 5febef5

Browse files
committed
Only allow floating-point dtypes in the array API __pow__ and __truediv__
See data-apis/array-api#221.
1 parent 6379138 commit 5febef5

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

numpy/_array_api/_array_object.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,8 @@ def __pow__(self: Array, other: Union[int, float, Array], /) -> Array:
503503

504504
if isinstance(other, (int, float, bool)):
505505
other = self._promote_scalar(other)
506+
if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
507+
raise TypeError('Only floating-point dtypes are allowed in __pow__')
506508
# Note: NumPy's __pow__ does not follow type promotion rules for 0-d
507509
# arrays, so we use pow() here instead.
508510
return pow(self, other)
@@ -548,6 +550,8 @@ def __truediv__(self: Array, other: Union[int, float, Array], /) -> Array:
548550
"""
549551
if isinstance(other, (int, float, bool)):
550552
other = self._promote_scalar(other)
553+
if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
554+
raise TypeError('Only floating-point dtypes are allowed in __truediv__')
551555
self, other = self._normalize_two_args(self, other)
552556
res = self._array.__truediv__(other._array)
553557
return self.__class__._new(res)
@@ -744,6 +748,8 @@ def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array:
744748
"""
745749
if isinstance(other, (int, float, bool)):
746750
other = self._promote_scalar(other)
751+
if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
752+
raise TypeError('Only floating-point dtypes are allowed in __pow__')
747753
self._array.__ipow__(other._array)
748754
return self
749755

@@ -756,6 +762,8 @@ def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array:
756762

757763
if isinstance(other, (int, float, bool)):
758764
other = self._promote_scalar(other)
765+
if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
766+
raise TypeError('Only floating-point dtypes are allowed in __pow__')
759767
# Note: NumPy's __pow__ does not follow the spec type promotion rules
760768
# for 0-d arrays, so we use pow() here instead.
761769
return pow(other, self)
@@ -810,6 +818,8 @@ def __itruediv__(self: Array, other: Union[int, float, Array], /) -> Array:
810818
"""
811819
if isinstance(other, (int, float, bool)):
812820
other = self._promote_scalar(other)
821+
if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
822+
raise TypeError('Only floating-point dtypes are allowed in __truediv__')
813823
self._array.__itruediv__(other._array)
814824
return self
815825

@@ -820,6 +830,8 @@ def __rtruediv__(self: Array, other: Union[int, float, Array], /) -> Array:
820830
"""
821831
if isinstance(other, (int, float, bool)):
822832
other = self._promote_scalar(other)
833+
if self.dtype not in _floating_dtypes or other.dtype not in _floating_dtypes:
834+
raise TypeError('Only floating-point dtypes are allowed in __truediv__')
823835
self, other = self._normalize_two_args(self, other)
824836
res = self._array.__rtruediv__(other._array)
825837
return self.__class__._new(res)

0 commit comments

Comments
 (0)