Skip to content

Commit 687b8ea

Browse files
authored
Add inplace matrix multiplication (#2147)
* Add support for inplace matrix multiplication * Raise ValueError exception per axes keyword * Add rendering documentation for reflected and inplace operations * Exclude __reduce__ method from rendering documentation * Align expection text for inplace matrix multiplication * Split too long line
1 parent 0517413 commit 687b8ea

File tree

4 files changed

+192
-61
lines changed

4 files changed

+192
-61
lines changed

doc/reference/ndarray.rst

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ Comparison operators:
249249
dpnp.ndarray.__eq__
250250
dpnp.ndarray.__ne__
251251

252-
Truth value of an array (:func:`bool()`):
252+
Truth value of an array (:class:`bool() <bool>`):
253253

254254
.. autosummary::
255255
:toctree: generated/
@@ -260,11 +260,11 @@ Truth value of an array (:func:`bool()`):
260260

261261
Truth-value testing of an array invokes
262262
:meth:`dpnp.ndarray.__bool__`, which raises an error if the number of
263-
elements in the array is larger than 1, because the truth value
263+
elements in the array is not 1, because the truth value
264264
of such arrays is ambiguous. Use :meth:`.any() <dpnp.ndarray.any>` and
265265
:meth:`.all() <dpnp.ndarray.all>` instead to be clear about what is meant
266-
in such cases. (If the number of elements is 0, the array evaluates
267-
to ``False``.)
266+
in such cases. (If you wish to check for whether an array is empty,
267+
use for example ``.size > 0``.)
268268

269269

270270
Unary operations:
@@ -300,6 +300,26 @@ Arithmetic:
300300
dpnp.ndarray.__xor__
301301

302302

303+
Arithmetic, reflected:
304+
305+
.. autosummary::
306+
:toctree: generated/
307+
:nosignatures:
308+
309+
dpnp.ndarray.__radd__
310+
dpnp.ndarray.__rsub__
311+
dpnp.ndarray.__rmul__
312+
dpnp.ndarray.__rtruediv__
313+
dpnp.ndarray.__rfloordiv__
314+
dpnp.ndarray.__rmod__
315+
dpnp.ndarray.__rpow__
316+
dpnp.ndarray.__rlshift__
317+
dpnp.ndarray.__rrshift__
318+
dpnp.ndarray.__rand__
319+
dpnp.ndarray.__ror__
320+
dpnp.ndarray.__rxor__
321+
322+
303323
Arithmetic, in-place:
304324

305325
.. autosummary::
@@ -326,6 +346,8 @@ Matrix Multiplication:
326346
:toctree: generated/
327347

328348
dpnp.ndarray.__matmul__
349+
dpnp.ndarray.__rmatmul__
350+
dpnp.ndarray.__imatmul__
329351

330352

331353
Special methods

dpnp/dpnp_array.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
# *****************************************************************************
2626

2727
import dpctl.tensor as dpt
28+
from dpctl.tensor._numpy_helper import AxisError
2829

2930
import dpnp
3031

@@ -205,6 +206,7 @@ def __bool__(self):
205206
return self._array_obj.__bool__()
206207

207208
# '__class__',
209+
# `__class_getitem__`,
208210

209211
def __complex__(self):
210212
return self._array_obj.__complex__()
@@ -335,6 +337,8 @@ def __getitem__(self, key):
335337
res._array_obj = item
336338
return res
337339

340+
# '__getstate__',
341+
338342
def __gt__(self, other):
339343
"""Return ``self>value``."""
340344
return dpnp.greater(self, other)
@@ -361,7 +365,31 @@ def __ilshift__(self, other):
361365
dpnp.left_shift(self, other, out=self)
362366
return self
363367

364-
# '__imatmul__',
368+
def __imatmul__(self, other):
369+
"""Return ``self@=value``."""
370+
371+
"""
372+
Unlike `matmul(a, b, out=a)` we ensure that the result is not broadcast
373+
if the result without `out` would have less dimensions than `a`.
374+
Since the signature of matmul is '(n?,k),(k,m?)->(n?,m?)' this is the
375+
case exactly when the second operand has both core dimensions.
376+
We have to enforce this check by passing the correct `axes=`.
377+
"""
378+
if self.ndim == 1:
379+
axes = [(-1,), (-2, -1), (-1,)]
380+
else:
381+
axes = [(-2, -1), (-2, -1), (-2, -1)]
382+
383+
try:
384+
dpnp.matmul(self, other, out=self, axes=axes)
385+
except AxisError:
386+
# AxisError should indicate that the axes argument didn't work out
387+
# which should mean the second operand not being 2 dimensional.
388+
raise ValueError(
389+
"inplace matrix multiplication requires the first operand to "
390+
"have at least one and the second at least two dimensions."
391+
)
392+
return self
365393

366394
def __imod__(self, other):
367395
"""Return ``self%=value``."""
@@ -469,9 +497,11 @@ def __pow__(self, other):
469497
return dpnp.power(self, other)
470498

471499
def __radd__(self, other):
500+
"""Return ``value+self``."""
472501
return dpnp.add(other, self)
473502

474503
def __rand__(self, other):
504+
"""Return ``value&self``."""
475505
return dpnp.bitwise_and(other, self)
476506

477507
# '__rdivmod__',
@@ -483,40 +513,51 @@ def __repr__(self):
483513
return dpt.usm_ndarray_repr(self._array_obj, prefix="array")
484514

485515
def __rfloordiv__(self, other):
516+
"""Return ``value//self``."""
486517
return dpnp.floor_divide(self, other)
487518

488519
def __rlshift__(self, other):
520+
"""Return ``value<<self``."""
489521
return dpnp.left_shift(other, self)
490522

491523
def __rmatmul__(self, other):
524+
"""Return ``value@self``."""
492525
return dpnp.matmul(other, self)
493526

494527
def __rmod__(self, other):
528+
"""Return ``value%self``."""
495529
return dpnp.remainder(other, self)
496530

497531
def __rmul__(self, other):
532+
"""Return ``value*self``."""
498533
return dpnp.multiply(other, self)
499534

500535
def __ror__(self, other):
536+
"""Return ``value|self``."""
501537
return dpnp.bitwise_or(other, self)
502538

503539
def __rpow__(self, other):
540+
"""Return ``value**self``."""
504541
return dpnp.power(other, self)
505542

506543
def __rrshift__(self, other):
544+
"""Return ``value>>self``."""
507545
return dpnp.right_shift(other, self)
508546

509547
def __rshift__(self, other):
510548
"""Return ``self>>value``."""
511549
return dpnp.right_shift(self, other)
512550

513551
def __rsub__(self, other):
552+
"""Return ``value-self``."""
514553
return dpnp.subtract(other, self)
515554

516555
def __rtruediv__(self, other):
556+
"""Return ``value/self``."""
517557
return dpnp.true_divide(other, self)
518558

519559
def __rxor__(self, other):
560+
"""Return ``value^self``."""
520561
return dpnp.bitwise_xor(other, self)
521562

522563
# '__setattr__',

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import dpctl.tensor._tensor_impl as ti
2929
import dpctl.utils as dpu
3030
import numpy
31-
from dpctl.tensor._numpy_helper import normalize_axis_tuple
31+
from dpctl.tensor._numpy_helper import AxisError, normalize_axis_tuple
3232
from dpctl.utils import ExecutionPlacementError
3333

3434
import dpnp
@@ -525,15 +525,15 @@ def _validate_internal(axes, i, ndim):
525525
)
526526

527527
if len(axes) != 1:
528-
raise ValueError(
528+
raise AxisError(
529529
f"Axes item {i} should be a tuple with a single element, or an integer."
530530
)
531531
else:
532532
iter = 2
533533
if not isinstance(axes, tuple):
534534
raise TypeError(f"Axes item {i} should be a tuple.")
535535
if len(axes) != 2:
536-
raise ValueError(
536+
raise AxisError(
537537
f"Axes item {i} should be a tuple with 2 elements."
538538
)
539539

@@ -563,7 +563,7 @@ def _validate_internal(axes, i, ndim):
563563

564564
if x1_ndim == 1 and x2_ndim == 1:
565565
if axes[2] != ():
566-
raise TypeError("Axes item 2 should be an empty tuple.")
566+
raise AxisError("Axes item 2 should be an empty tuple.")
567567
elif x1_ndim == 1 or x2_ndim == 1:
568568
axes[2] = _validate_internal(axes[2], 2, 1)
569569
else:

0 commit comments

Comments
 (0)