39
39
40
40
41
41
import numpy
42
+ from numpy .core .numeric import normalize_axis_tuple
42
43
43
44
import dpnp
44
45
from dpnp .dpnp_algo import *
@@ -66,9 +67,9 @@ def dot(a, b, out=None):
66
67
67
68
Parameters
68
69
----------
69
- a : {dpnp_array , usm_ndarray, scalar}
70
+ a : {dpnp.ndarray , usm_ndarray, scalar}
70
71
First input array. Both inputs `a` and `b` can not be scalars at the same time.
71
- b : {dpnp_array , usm_ndarray, scalar}
72
+ b : {dpnp.ndarray , usm_ndarray, scalar}
72
73
Second input array. Both inputs `a` and `b` can not be scalars at the same time.
73
74
out : {dpnp.ndarray, usm_ndarray}, optional
74
75
Alternative output array in which to place the result. It must have
@@ -404,42 +405,152 @@ def outer(x1, x2, out=None):
404
405
return call_origin (numpy .outer , x1 , x2 , out = out )
405
406
406
407
407
def tensordot(a, b, axes=2):
    r"""
    Compute tensor dot product along specified axes.

    For full documentation refer to :obj:`numpy.tensordot`.

    Parameters
    ----------
    a : {dpnp.ndarray, usm_ndarray, scalar}
        First input array. Both inputs `a` and `b` cannot be scalars
        at the same time.
    b : {dpnp.ndarray, usm_ndarray, scalar}
        Second input array. Both inputs `a` and `b` cannot be scalars
        at the same time.
    axes : int or (2,) array_like
        * integer_like
          If an int `N`, sum over the last `N` axes of `a` and the first `N`
          axes of `b` in order. The sizes of the corresponding axes must
          match.
        * (2,) array_like
          Or, a list of axes to be summed over, first sequence applying to
          `a`, second to `b`. Both elements array_like must be of the same
          length.

    Returns
    -------
    out : dpnp.ndarray
        Returns the tensordot product of `a` and `b`.

    Raises
    ------
    TypeError
        If `axes` is not iterable and is not an ``int``.
    ValueError
        If `axes` is a negative integer, if an iterable `axes` does not
        contain exactly two elements, if the two axis sequences have
        different lengths, or if the sizes of `a` and `b` differ along a
        pair of axes to be summed over.

    See Also
    --------
    :obj:`dpnp.dot` : Returns the dot product.
    :obj:`dpnp.einsum` : Evaluates the Einstein summation convention
                         on the operands.

    Notes
    -----
    Three common use cases are:
        * ``axes = 0`` : tensor product :math:`a \otimes b`
        * ``axes = 1`` : tensor dot product :math:`a \cdot b`
        * ``axes = 2`` : (default) tensor double contraction :math:`a:b`

    When `axes` is integer, the sequence for evaluation will be: first
    the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
    Nth axis in `b` last.

    When there is more than one axis to sum over - and they are not the last
    (first) axes of `a` (`b`) - the argument `axes` should consist of
    two sequences of the same length, with the first axis to sum over given
    first in both sequences, the second axis second, and so forth.

    The shape of the result consists of the non-contracted axes of the
    first tensor, followed by the non-contracted axes of the second.

    Examples
    --------
    >>> import dpnp as np
    >>> a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    >>> b = np.array([1, 2, 3])
    >>> np.tensordot(a, b, 1)
    array([14, 32, 50])

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
    >>> c.shape
    (5, 2)
    >>> c
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    A slower but equivalent way of computing the same...

    >>> d = np.zeros((5,2))
    >>> for i in range(5):
    ...   for j in range(2):
    ...     for k in range(3):
    ...       for n in range(4):
    ...         d[i,j] += a[k,n,i] * b[n,k,j]
    >>> c == d
    array([[ True,  True],
           [ True,  True],
           [ True,  True],
           [ True,  True],
           [ True,  True]])

    """

    dpnp.check_supported_arrays_type(a, b, scalar_type=True)

    # A scalar operand is wrapped into an array created with the SYCL queue
    # and USM type of the other operand.
    if dpnp.isscalar(a):
        a = dpnp.array(a, sycl_queue=b.sycl_queue, usm_type=b.usm_type)
    elif dpnp.isscalar(b):
        b = dpnp.array(b, sycl_queue=a.sycl_queue, usm_type=a.usm_type)

    try:
        iter(axes)
    except TypeError:
        # `axes` is not iterable, so it has to be a single integer count.
        if not isinstance(axes, int):
            raise TypeError("Axes must be an integer.") from None
        if axes < 0:
            # A negative count would produce two empty ranges below and be
            # silently evaluated as ``axes=0``; reject it explicitly.
            raise ValueError("Axes must be a non-negative integer.") from None
        axes_a = tuple(range(-axes, 0))
        axes_b = tuple(range(0, axes))
    else:
        if len(axes) != 2:
            raise ValueError("Axes must consist of two sequences.")

        axes_a, axes_b = axes
        # Allow the short form ``axes=(i, j)`` with one axis per operand.
        axes_a = (axes_a,) if dpnp.isscalar(axes_a) else axes_a
        axes_b = (axes_b,) if dpnp.isscalar(axes_b) else axes_b

        if len(axes_a) != len(axes_b):
            raise ValueError("Axes length mismatch.")

    a_shape = a.shape
    b_shape = b.shape
    for axis_a, axis_b in zip(axes_a, axes_b):
        if a_shape[axis_a] != b_shape[axis_b]:
            raise ValueError(
                "shape of input arrays is not similar at requested axes."
            )

    # Make the axes non-negative
    a_ndim = a.ndim
    b_ndim = b.ndim
    axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis")
    axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis")

    # Move the axes to sum over, to the end of "a"
    notin = tuple(k for k in range(a_ndim) if k not in axes_a)
    newaxes_a = notin + axes_a
    # int() is needed because numpy.prod of an empty list is the float 1.0
    N1 = int(numpy.prod([a_shape[ax] for ax in notin]))
    N2 = int(numpy.prod([a_shape[ax] for ax in axes_a]))
    newshape_a = (N1, N2)
    olda = [a_shape[axis] for axis in notin]

    # Move the axes to sum over, to the front of "b"
    notin = tuple(k for k in range(b_ndim) if k not in axes_b)
    newaxes_b = tuple(axes_b + notin)
    N1 = int(numpy.prod([b_shape[ax] for ax in axes_b]))
    N2 = int(numpy.prod([b_shape[ax] for ax in notin]))
    newshape_b = (N1, N2)
    oldb = [b_shape[axis] for axis in notin]

    # Both operands are now 2-D, so the contraction is a single matrix
    # product; the kept axes are restored on the result afterwards.
    at = a.transpose(newaxes_a).reshape(newshape_a)
    bt = b.transpose(newaxes_b).reshape(newshape_b)
    res = dpnp.matmul(at, bt)

    return res.reshape(olda + oldb)
443
554
444
555
445
556
def vdot (a , b ):
@@ -450,11 +561,11 @@ def vdot(a, b):
450
561
451
562
Parameters
452
563
----------
453
- a : {dpnp_array , usm_ndarray, scalar}
564
+ a : {dpnp.ndarray , usm_ndarray, scalar}
454
565
First input array. Both inputs `a` and `b` can not be
455
566
scalars at the same time. If `a` is complex, the complex
456
567
conjugate is taken before the calculation of the dot product.
457
- b : {dpnp_array , usm_ndarray, scalar}
568
+ b : {dpnp.ndarray , usm_ndarray, scalar}
458
569
Second input array. Both inputs `a` and `b` can not be
459
570
scalars at the same time.
460
571
0 commit comments