Skip to content

Commit 0178080

Browse files
committed
Fix in-place operators to not recreate the wrapper class
1 parent cad21e9 commit 0178080

File tree

1 file changed

+24
-26
lines changed

1 file changed

+24
-26
lines changed

numpy/_array_api/_array_object.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -568,10 +568,8 @@ def __iadd__(self: array, other: Union[int, float, array], /) -> array:
568568
"""
569569
if isinstance(other, (int, float, bool)):
570570
other = self._promote_scalar(other)
571-
res = self._array.__iadd__(other._array)
572-
if res.dtype != self.dtype:
573-
raise RuntimeError
574-
return self.__class__._new(res)
571+
self._array.__iadd__(other._array)
572+
return self
575573

576574
@np.errstate(all='ignore')
577575
def __radd__(self: array, other: Union[int, float, array], /) -> array:
@@ -590,8 +588,8 @@ def __iand__(self: array, other: Union[int, bool, array], /) -> array:
590588
"""
591589
if isinstance(other, (int, float, bool)):
592590
other = self._promote_scalar(other)
593-
res = self._array.__iand__(other._array)
594-
return self.__class__._new(res)
591+
self._array.__iand__(other._array)
592+
return self
595593

596594
def __rand__(self: array, other: Union[int, bool, array], /) -> array:
597595
"""
@@ -610,8 +608,8 @@ def __ifloordiv__(self: array, other: Union[int, float, array], /) -> array:
610608
"""
611609
if isinstance(other, (int, float, bool)):
612610
other = self._promote_scalar(other)
613-
res = self._array.__ifloordiv__(other._array)
614-
return self.__class__._new(res)
611+
self._array.__ifloordiv__(other._array)
612+
return self
615613

616614
@np.errstate(all='ignore')
617615
def __rfloordiv__(self: array, other: Union[int, float, array], /) -> array:
@@ -630,8 +628,8 @@ def __ilshift__(self: array, other: Union[int, array], /) -> array:
630628
"""
631629
if isinstance(other, (int, float, bool)):
632630
other = self._promote_scalar(other)
633-
res = self._array.__ilshift__(other._array)
634-
return self.__class__._new(res)
631+
self._array.__ilshift__(other._array)
632+
return self
635633

636634
def __rlshift__(self: array, other: Union[int, array], /) -> array:
637635
"""
@@ -675,8 +673,8 @@ def __imod__(self: array, other: Union[int, float, array], /) -> array:
675673
"""
676674
if isinstance(other, (int, float, bool)):
677675
other = self._promote_scalar(other)
678-
res = self._array.__imod__(other._array)
679-
return self.__class__._new(res)
676+
self._array.__imod__(other._array)
677+
return self
680678

681679
@np.errstate(all='ignore')
682680
def __rmod__(self: array, other: Union[int, float, array], /) -> array:
@@ -696,8 +694,8 @@ def __imul__(self: array, other: Union[int, float, array], /) -> array:
696694
"""
697695
if isinstance(other, (int, float, bool)):
698696
other = self._promote_scalar(other)
699-
res = self._array.__imul__(other._array)
700-
return self.__class__._new(res)
697+
self._array.__imul__(other._array)
698+
return self
701699

702700
@np.errstate(all='ignore')
703701
def __rmul__(self: array, other: Union[int, float, array], /) -> array:
@@ -716,8 +714,8 @@ def __ior__(self: array, other: Union[int, bool, array], /) -> array:
716714
"""
717715
if isinstance(other, (int, float, bool)):
718716
other = self._promote_scalar(other)
719-
res = self._array.__ior__(other._array)
720-
return self.__class__._new(res)
717+
self._array.__ior__(other._array)
718+
return self
721719

722720
def __ror__(self: array, other: Union[int, bool, array], /) -> array:
723721
"""
@@ -736,8 +734,8 @@ def __ipow__(self: array, other: Union[int, float, array], /) -> array:
736734
"""
737735
if isinstance(other, (int, float, bool)):
738736
other = self._promote_scalar(other)
739-
res = self._array.__ipow__(other._array)
740-
return self.__class__._new(res)
737+
self._array.__ipow__(other._array)
738+
return self
741739

742740
@np.errstate(all='ignore')
743741
def __rpow__(self: array, other: Union[int, float, array], /) -> array:
@@ -758,8 +756,8 @@ def __irshift__(self: array, other: Union[int, array], /) -> array:
758756
"""
759757
if isinstance(other, (int, float, bool)):
760758
other = self._promote_scalar(other)
761-
res = self._array.__irshift__(other._array)
762-
return self.__class__._new(res)
759+
self._array.__irshift__(other._array)
760+
return self
763761

764762
def __rrshift__(self: array, other: Union[int, array], /) -> array:
765763
"""
@@ -781,8 +779,8 @@ def __isub__(self: array, other: Union[int, float, array], /) -> array:
781779
"""
782780
if isinstance(other, (int, float, bool)):
783781
other = self._promote_scalar(other)
784-
res = self._array.__isub__(other._array)
785-
return self.__class__._new(res)
782+
self._array.__isub__(other._array)
783+
return self
786784

787785
@np.errstate(all='ignore')
788786
def __rsub__(self: array, other: Union[int, float, array], /) -> array:
@@ -802,8 +800,8 @@ def __itruediv__(self: array, other: Union[int, float, array], /) -> array:
802800
"""
803801
if isinstance(other, (int, float, bool)):
804802
other = self._promote_scalar(other)
805-
res = self._array.__itruediv__(other._array)
806-
return self.__class__._new(res)
803+
self._array.__itruediv__(other._array)
804+
return self
807805

808806
@np.errstate(all='ignore')
809807
def __rtruediv__(self: array, other: Union[int, float, array], /) -> array:
@@ -822,8 +820,8 @@ def __ixor__(self: array, other: Union[int, bool, array], /) -> array:
822820
"""
823821
if isinstance(other, (int, float, bool)):
824822
other = self._promote_scalar(other)
825-
res = self._array.__ixor__(other._array)
826-
return self.__class__._new(res)
823+
self._array.__ixor__(other._array)
824+
return self
827825

828826
def __rxor__(self: array, other: Union[int, bool, array], /) -> array:
829827
"""

0 commit comments

Comments
 (0)