Skip to content

Commit 20f90ed

Browse files
vtavanaantonwolfy
authored andcommitted
implement dpnp.tensordot (#1699)
* implement dpnp.tensordot * update doc string * address comments * fix doc string * update scaling factor * add TODO comment --------- Co-authored-by: Anton <[email protected]>
1 parent 2c1fc3d commit 20f90ed

File tree

10 files changed

+298
-76
lines changed

10 files changed

+298
-76
lines changed

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 132 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040

4141
import numpy
42+
from numpy.core.numeric import normalize_axis_tuple
4243

4344
import dpnp
4445
from dpnp.dpnp_algo import *
@@ -66,9 +67,9 @@ def dot(a, b, out=None):
6667
6768
Parameters
6869
----------
69-
a : {dpnp_array, usm_ndarray, scalar}
70+
a : {dpnp.ndarray, usm_ndarray, scalar}
7071
First input array. Both inputs `a` and `b` can not be scalars at the same time.
71-
b : {dpnp_array, usm_ndarray, scalar}
72+
b : {dpnp.ndarray, usm_ndarray, scalar}
7273
Second input array. Both inputs `a` and `b` can not be scalars at the same time.
7374
out : {dpnp.ndarray, usm_ndarray}, optional
7475
Alternative output array in which to place the result. It must have
@@ -404,42 +405,152 @@ def outer(x1, x2, out=None):
404405
return call_origin(numpy.outer, x1, x2, out=out)
405406

406407

407-
def tensordot(x1, x2, axes=2):
408-
"""
408+
def tensordot(a, b, axes=2):
409+
r"""
409410
Compute tensor dot product along specified axes.
410411
411412
For full documentation refer to :obj:`numpy.tensordot`.
412413
413-
Limitations
414-
-----------
415-
Parameters `x1` and `x2` are supported as :obj:`dpnp.ndarray`.
416-
Keyword argument `kwargs` is currently unsupported.
417-
Parameter `axes` is supported only with value ``1``.
418-
Otherwise the functions will be executed sequentially on CPU.
419-
Input array data types are limited by supported DPNP :ref:`Data types`.
414+
Parameters
415+
----------
416+
a : {dpnp.ndarray, usm_ndarray, scalar}
417+
First input array. Both inputs `a` and `b` can not be scalars at the same time.
418+
b : {dpnp.ndarray, usm_ndarray, scalar}
419+
Second input array. Both inputs `a` and `b` can not be scalars at the same time.
420+
axes : int or (2,) array_like
421+
* integer_like
422+
If an int `N`, sum over the last `N` axes of `a` and the first `N` axes
423+
of `b` in order. The sizes of the corresponding axes must match.
424+
* (2,) array_like
425+
Or, a list of axes to be summed over, first sequence applying to `a`,
426+
second to `b`. Both elements array_like must be of the same length.
427+
428+
Returns
429+
-------
430+
out : dpnp.ndarray
431+
Returns the tensordot product of `a` and `b`.
420432
421433
See Also
422434
--------
423435
:obj:`dpnp.dot` : Returns the dot product.
424436
:obj:`dpnp.einsum` : Evaluates the Einstein summation convention on the operands.
425437
438+
Notes
439+
-----
440+
Three common use cases are:
441+
* ``axes = 0`` : tensor product :math:`a \otimes b`
442+
* ``axes = 1`` : tensor dot product :math:`a \cdot b`
443+
* ``axes = 2`` : (default) tensor double contraction :math:`a:b`
444+
445+
When `axes` is integer, the sequence for evaluation will be: first
446+
the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
447+
Nth axis in `b` last.
448+
449+
When there is more than one axis to sum over - and they are not the last
450+
(first) axes of `a` (`b`) - the argument `axes` should consist of
451+
two sequences of the same length, with the first axis to sum over given
452+
first in both sequences, the second axis second, and so forth.
453+
454+
The shape of the result consists of the non-contracted axes of the
455+
first tensor, followed by the non-contracted axes of the second.
456+
426457
Examples
427458
--------
428459
>>> import dpnp as np
429460
>>> a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
430461
>>> b = np.array([1, 2, 3])
431-
>>> result = np.tensordot(a, b, 1)
432-
>>> [x for x in result]
433-
[14, 32, 50]
462+
>>> np.tensordot(a, b, 1)
463+
array([14, 32, 50])
464+
465+
>>> a = np.arange(60.).reshape(3,4,5)
466+
>>> b = np.arange(24.).reshape(4,3,2)
467+
>>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
468+
>>> c.shape
469+
(5, 2)
470+
>>> c
471+
array([[4400., 4730.],
472+
[4532., 4874.],
473+
[4664., 5018.],
474+
[4796., 5162.],
475+
[4928., 5306.]])
476+
477+
A slower but equivalent way of computing the same...
478+
479+
>>> d = np.zeros((5,2))
480+
>>> for i in range(5):
481+
... for j in range(2):
482+
... for k in range(3):
483+
... for n in range(4):
484+
... d[i,j] += a[k,n,i] * b[n,k,j]
485+
>>> c == d
486+
array([[ True, True],
487+
[ True, True],
488+
[ True, True],
489+
[ True, True],
490+
[ True, True]])
434491
435492
"""
436493

437-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
438-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False)
439-
if x1_desc and x2_desc and (axes == 1):
440-
return dpnp_tensordot_not_implemented(x1_desc, x2_desc) # dpnp_matmul
494+
dpnp.check_supported_arrays_type(a, b, scalar_type=True)
441495

442-
return call_origin(numpy.tensordot, x1, x2, axes)
496+
if dpnp.isscalar(a):
497+
a = dpnp.array(a, sycl_queue=b.sycl_queue, usm_type=b.usm_type)
498+
elif dpnp.isscalar(b):
499+
b = dpnp.array(b, sycl_queue=a.sycl_queue, usm_type=a.usm_type)
500+
501+
try:
502+
iter(axes)
503+
except Exception:
504+
if not isinstance(axes, int):
505+
raise TypeError("Axes must be an integer.")
506+
axes_a = tuple(range(-axes, 0))
507+
axes_b = tuple(range(0, axes))
508+
else:
509+
if len(axes) != 2:
510+
raise ValueError("Axes must consist of two sequences.")
511+
512+
axes_a, axes_b = axes
513+
axes_a = (axes_a,) if dpnp.isscalar(axes_a) else axes_a
514+
axes_b = (axes_b,) if dpnp.isscalar(axes_b) else axes_b
515+
516+
if len(axes_a) != len(axes_b):
517+
raise ValueError("Axes length mismatch.")
518+
519+
a_shape = a.shape
520+
b_shape = b.shape
521+
for axis_a, axis_b in zip(axes_a, axes_b):
522+
if a_shape[axis_a] != b_shape[axis_b]:
523+
raise ValueError(
524+
"shape of input arrays is not similar at requested axes."
525+
)
526+
527+
# Make the axes non-negative
528+
a_ndim = a.ndim
529+
b_ndim = b.ndim
530+
axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis")
531+
axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis")
532+
533+
# Move the axes to sum over, to the end of "a"
534+
notin = tuple(k for k in range(a_ndim) if k not in axes_a)
535+
newaxes_a = notin + axes_a
536+
N1 = int(numpy.prod([a_shape[ax] for ax in notin]))
537+
N2 = int(numpy.prod([a_shape[ax] for ax in axes_a]))
538+
newshape_a = (N1, N2)
539+
olda = [a_shape[axis] for axis in notin]
540+
541+
# Move the axes to sum over, to the front of "b"
542+
notin = tuple(k for k in range(b_ndim) if k not in axes_b)
543+
newaxes_b = tuple(axes_b + notin)
544+
N1 = int(numpy.prod([b_shape[ax] for ax in axes_b]))
545+
N2 = int(numpy.prod([b_shape[ax] for ax in notin]))
546+
newshape_b = (N1, N2)
547+
oldb = [b_shape[axis] for axis in notin]
548+
549+
at = a.transpose(newaxes_a).reshape(newshape_a)
550+
bt = b.transpose(newaxes_b).reshape(newshape_b)
551+
res = dpnp.matmul(at, bt)
552+
553+
return res.reshape(olda + oldb)
443554

444555

445556
def vdot(a, b):
@@ -450,11 +561,11 @@ def vdot(a, b):
450561
451562
Parameters
452563
----------
453-
a : {dpnp_array, usm_ndarray, scalar}
564+
a : {dpnp.ndarray, usm_ndarray, scalar}
454565
First input array. Both inputs `a` and `b` can not be
455566
scalars at the same time. If `a` is complex, the complex
456567
conjugate is taken before the calculation of the dot product.
457-
b : {dpnp_array, usm_ndarray, scalar}
568+
b : {dpnp.ndarray, usm_ndarray, scalar}
458569
Second input array. Both inputs `a` and `b` can not be
459570
scalars at the same time.
460571

dpnp/dpnp_iface_sorting.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# cython: language_level=3
2-
# distutils: language = c++
31
# -*- coding: utf-8 -*-
42
# *****************************************************************************
53
# Copyright (c) 2016-2024, Intel Corporation

tests/helper.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88

99

1010
def assert_dtype_allclose(
11-
dpnp_arr, numpy_arr, check_type=True, check_only_type_kind=False
11+
dpnp_arr,
12+
numpy_arr,
13+
check_type=True,
14+
check_only_type_kind=False,
15+
factor=8,
1216
):
1317
"""
1418
Assert DPNP and NumPy array based on maximum dtype resolution of input arrays
@@ -28,6 +32,7 @@ def assert_dtype_allclose(
2832
The 'check_only_type_kind' parameter (False by default) asserts only equal type kinds
2933
for all data types supported by DPNP when set to True.
3034
It is effective only when 'check_type' is also set to True.
35+
The parameter `factor` scales the resolution used for comparing the arrays.
3136
3237
"""
3338

@@ -44,7 +49,7 @@ def assert_dtype_allclose(
4449
if is_inexact(numpy_arr)
4550
else -dpnp.inf
4651
)
47-
tol = 8 * max(tol_dpnp, tol_numpy)
52+
tol = factor * max(tol_dpnp, tol_numpy)
4853
assert_allclose(dpnp_arr.asnumpy(), numpy_arr, atol=tol, rtol=tol)
4954
if check_type:
5055
numpy_arr_dtype = numpy_arr.dtype

tests/skipped_tests.tbl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,10 +335,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test
335335
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge
336336
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large
337337
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two
338-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_tensordot_zero_dim
339-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot
340-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_int_axes
341-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_list_axes
342338

343339
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed
344340
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -437,10 +437,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test
437437
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge
438438
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large
439439
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two
440-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot
441-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_int_axes
442-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_list_axes
443-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_tensordot_zero_dim
444440

445441
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed
446442
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal

0 commit comments

Comments
 (0)