Skip to content

Commit 1379623

Browse files
committed
Fix the __imatmul__ method in the array API namespace
1 parent 0178080 commit 1379623

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

numpy/_array_api/_array_object.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -648,12 +648,21 @@ def __imatmul__(self: array, other: array, /) -> array:
648648
"""
649649
Performs the operation __imatmul__.
650650
"""
651+
# Note: NumPy does not implement __imatmul__.
652+
651653
if isinstance(other, (int, float, bool)):
652654
# matmul is not defined for scalars, but without this, we may get
653655
# the wrong error message from asarray.
654656
other = self._promote_scalar(other)
655-
res = self._array.__imatmul__(other._array)
656-
return self.__class__._new(res)
657+
# __imatmul__ can only be allowed when it would not change the shape
658+
# of self.
659+
other_shape = other.shape
660+
if self.shape == () or other_shape == ():
661+
raise ValueError("@= requires at least one dimension")
662+
if len(other_shape) == 1 or other_shape[-1] != other_shape[-2]:
663+
raise ValueError("@= cannot change the shape of the input array")
664+
self._array[:] = self._array.__matmul__(other._array)
665+
return self
657666

658667
def __rmatmul__(self: array, other: array, /) -> array:
659668
"""

0 commit comments

Comments
 (0)