Skip to content

Add inplace matrix multiplication #2147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions doc/reference/ndarray.rst
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ Comparison operators:
dpnp.ndarray.__eq__
dpnp.ndarray.__ne__

Truth value of an array (:func:`bool()`):
Truth value of an array (:class:`bool() <bool>`):

.. autosummary::
:toctree: generated/
Expand All @@ -260,11 +260,11 @@ Truth value of an array (:func:`bool()`):

Truth-value testing of an array invokes
:meth:`dpnp.ndarray.__bool__`, which raises an error if the number of
elements in the array is larger than 1, because the truth value
elements in the array is not 1, because the truth value
of such arrays is ambiguous. Use :meth:`.any() <dpnp.ndarray.any>` and
:meth:`.all() <dpnp.ndarray.all>` instead to be clear about what is meant
in such cases. (If the number of elements is 0, the array evaluates
to ``False``.)
in such cases. (If you wish to check for whether an array is empty,
use for example ``.size > 0``.)


Unary operations:
Expand Down Expand Up @@ -300,6 +300,26 @@ Arithmetic:
dpnp.ndarray.__xor__


Arithmetic, reflected:

.. autosummary::
:toctree: generated/
:nosignatures:

dpnp.ndarray.__radd__
dpnp.ndarray.__rsub__
dpnp.ndarray.__rmul__
dpnp.ndarray.__rtruediv__
dpnp.ndarray.__rfloordiv__
dpnp.ndarray.__rmod__
dpnp.ndarray.__rpow__
dpnp.ndarray.__rlshift__
dpnp.ndarray.__rrshift__
dpnp.ndarray.__rand__
dpnp.ndarray.__ror__
dpnp.ndarray.__rxor__


Arithmetic, in-place:

.. autosummary::
Expand All @@ -326,6 +346,8 @@ Matrix Multiplication:
:toctree: generated/

dpnp.ndarray.__matmul__
dpnp.ndarray.__rmatmul__
dpnp.ndarray.__imatmul__


Special methods
Expand Down
43 changes: 42 additions & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# *****************************************************************************

import dpctl.tensor as dpt
from dpctl.tensor._numpy_helper import AxisError

import dpnp

Expand Down Expand Up @@ -205,6 +206,7 @@ def __bool__(self):
return self._array_obj.__bool__()

# '__class__',
# `__class_getitem__`,

def __complex__(self):
return self._array_obj.__complex__()
Expand Down Expand Up @@ -335,6 +337,8 @@ def __getitem__(self, key):
res._array_obj = item
return res

# '__getstate__',

def __gt__(self, other):
"""Return ``self>value``."""
return dpnp.greater(self, other)
Expand All @@ -361,7 +365,31 @@ def __ilshift__(self, other):
dpnp.left_shift(self, other, out=self)
return self

# '__imatmul__',
def __imatmul__(self, other):
"""Return ``self@=value``."""

"""
Unlike `matmul(a, b, out=a)` we ensure that the result is not broadcast
if the result without `out` would have less dimensions than `a`.
Since the signature of matmul is '(n?,k),(k,m?)->(n?,m?)' this is the
case exactly when the second operand has both core dimensions.
We have to enforce this check by passing the correct `axes=`.
"""
if self.ndim == 1:
axes = [(-1,), (-2, -1), (-1,)]
else:
axes = [(-2, -1), (-2, -1), (-2, -1)]

try:
dpnp.matmul(self, other, out=self, axes=axes)
except AxisError:
# AxisError should indicate that the axes argument didn't work out
# which should mean the second operand not being 2 dimensional.
raise ValueError(
"inplace matrix multiplication requires the first operand to "
"have at least one and the second at least two dimensions."
)
return self

def __imod__(self, other):
"""Return ``self%=value``."""
Expand Down Expand Up @@ -469,9 +497,11 @@ def __pow__(self, other):
return dpnp.power(self, other)

def __radd__(self, other):
"""Return ``value+self``."""
return dpnp.add(other, self)

def __rand__(self, other):
"""Return ``value&self``."""
return dpnp.bitwise_and(other, self)

# '__rdivmod__',
Expand All @@ -483,40 +513,51 @@ def __repr__(self):
return dpt.usm_ndarray_repr(self._array_obj, prefix="array")

def __rfloordiv__(self, other):
"""Return ``value//self``."""
return dpnp.floor_divide(self, other)

def __rlshift__(self, other):
"""Return ``value<<self``."""
return dpnp.left_shift(other, self)

def __rmatmul__(self, other):
"""Return ``value@self``."""
return dpnp.matmul(other, self)

def __rmod__(self, other):
"""Return ``value%self``."""
return dpnp.remainder(other, self)

def __rmul__(self, other):
"""Return ``value*self``."""
return dpnp.multiply(other, self)

def __ror__(self, other):
"""Return ``value|self``."""
return dpnp.bitwise_or(other, self)

def __rpow__(self, other):
"""Return ``value**self``."""
return dpnp.power(other, self)

def __rrshift__(self, other):
"""Return ``value>>self``."""
return dpnp.right_shift(other, self)

def __rshift__(self, other):
"""Return ``self>>value``."""
return dpnp.right_shift(self, other)

def __rsub__(self, other):
"""Return ``value-self``."""
return dpnp.subtract(other, self)

def __rtruediv__(self, other):
"""Return ``value/self``."""
return dpnp.true_divide(other, self)

def __rxor__(self, other):
"""Return ``value^self``."""
return dpnp.bitwise_xor(other, self)

# '__setattr__',
Expand Down
8 changes: 4 additions & 4 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import dpctl.tensor._tensor_impl as ti
import dpctl.utils as dpu
import numpy
from dpctl.tensor._numpy_helper import normalize_axis_tuple
from dpctl.tensor._numpy_helper import AxisError, normalize_axis_tuple
from dpctl.utils import ExecutionPlacementError

import dpnp
Expand Down Expand Up @@ -525,15 +525,15 @@ def _validate_internal(axes, i, ndim):
)

if len(axes) != 1:
raise ValueError(
raise AxisError(
f"Axes item {i} should be a tuple with a single element, or an integer."
)
else:
iter = 2
if not isinstance(axes, tuple):
raise TypeError(f"Axes item {i} should be a tuple.")
if len(axes) != 2:
raise ValueError(
raise AxisError(
f"Axes item {i} should be a tuple with 2 elements."
)

Expand Down Expand Up @@ -563,7 +563,7 @@ def _validate_internal(axes, i, ndim):

if x1_ndim == 1 and x2_ndim == 1:
if axes[2] != ():
raise TypeError("Axes item 2 should be an empty tuple.")
raise AxisError("Axes item 2 should be an empty tuple.")
elif x1_ndim == 1 or x2_ndim == 1:
axes[2] = _validate_internal(axes[2], 2, 1)
else:
Expand Down
Loading
Loading