|
39 | 39 |
|
40 | 40 |
|
41 | 41 | import numpy
|
| 42 | +from numpy.core.numeric import normalize_axis_tuple |
42 | 43 |
|
43 | 44 | import dpnp
|
44 | 45 | from dpnp.dpnp_algo import *
|
@@ -404,42 +405,152 @@ def outer(x1, x2, out=None):
|
404 | 405 | return call_origin(numpy.outer, x1, x2, out=out)
|
405 | 406 |
|
406 | 407 |
|
407 |
| -def tensordot(x1, x2, axes=2): |
408 |
| - """ |
| 408 | +def tensordot(a, b, axes=2): |
| 409 | + r""" |
409 | 410 | Compute tensor dot product along specified axes.
|
410 | 411 |
|
411 | 412 | For full documentation refer to :obj:`numpy.tensordot`.
|
412 | 413 |
|
413 |
| - Limitations |
414 |
| - ----------- |
415 |
| - Parameters `x1` and `x2` are supported as :obj:`dpnp.ndarray`. |
416 |
| - Keyword argument `kwargs` is currently unsupported. |
417 |
| - Parameter `axes` is supported only with value ``1``. |
418 |
| - Otherwise the functions will be executed sequentially on CPU. |
419 |
| - Input array data types are limited by supported DPNP :ref:`Data types`. |
| 414 | + Parameters |
| 415 | + ---------- |
| 416 | + a : {dpnp_array, usm_ndarray, scalar} |
| 417 | + First input array. Both inputs `a` and `b` cannot be scalars at the same time. |
| 418 | + b : {dpnp_array, usm_ndarray, scalar} |
| 419 | + Second input array. Both inputs `a` and `b` cannot be scalars at the same time. |
| 420 | + axes : int or (2,) array_like |
| 421 | + * integer_like |
| 422 | + If an int N, sum over the last N axes of `a` and the first N axes |
| 423 | + of `b` in order. The sizes of the corresponding axes must match. |
| 424 | + * (2,) array_like |
| 425 | + Or, a list of axes to be summed over, first sequence applying to `a`, |
| 426 | + second to `b`. Both array_like elements must be of the same length. |
| 427 | +
|
| 428 | + Returns |
| 429 | + ------- |
| 430 | + out : dpnp.ndarray |
| 431 | + Returns the tensordot product of `a` and `b`. |
420 | 432 |
|
421 | 433 | See Also
|
422 | 434 | --------
|
423 | 435 | :obj:`dpnp.dot` : Returns the dot product.
|
424 | 436 | :obj:`dpnp.einsum` : Evaluates the Einstein summation convention on the operands.
|
425 | 437 |
|
| 438 | + Notes |
| 439 | + ----- |
| 440 | + Three common use cases are: |
| 441 | + * ``axes = 0`` : tensor product :math:`a\otimes b` |
| 442 | + * ``axes = 1`` : tensor dot product :math:`a\cdot b` |
| 443 | + * ``axes = 2`` : (default) tensor double contraction :math:`a:b` |
| 444 | +
|
| 445 | + When `axes` is an integer, the sequence for evaluation will be: first |
| 446 | + the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and |
| 447 | + (N-1)th axis in `b` last. |
| 448 | +
|
| 449 | + When there is more than one axis to sum over - and they are not the last |
| 450 | + (first) axes of `a` (`b`) - the argument `axes` should consist of |
| 451 | + two sequences of the same length, with the first axis to sum over given |
| 452 | + first in both sequences, the second axis second, and so forth. |
| 453 | +
|
| 454 | + The shape of the result consists of the non-contracted axes of the |
| 455 | + first tensor, followed by the non-contracted axes of the second. |
| 456 | +
|
426 | 457 | Examples
|
427 | 458 | --------
|
428 | 459 | >>> import dpnp as np
|
429 | 460 | >>> a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
430 | 461 | >>> b = np.array([1, 2, 3])
|
431 |
| - >>> result = np.tensordot(a, b, 1) |
432 |
| - >>> [x for x in result] |
433 |
| - [14, 32, 50] |
| 462 | + >>> np.tensordot(a, b, 1) |
| 463 | + array([14, 32, 50]) |
| 464 | +
|
| 465 | + >>> a = np.arange(60.).reshape(3,4,5) |
| 466 | + >>> b = np.arange(24.).reshape(4,3,2) |
| 467 | + >>> c = np.tensordot(a,b, axes=([1,0],[0,1])) |
| 468 | + >>> c.shape |
| 469 | + (5, 2) |
| 470 | + >>> c |
| 471 | + array([[4400., 4730.], |
| 472 | + [4532., 4874.], |
| 473 | + [4664., 5018.], |
| 474 | + [4796., 5162.], |
| 475 | + [4928., 5306.]]) |
| 476 | +
|
| 477 | + A slower but equivalent way of computing the same... |
| 478 | +
|
| 479 | + >>> d = np.zeros((5,2)) |
| 480 | + >>> for i in range(5): |
| 481 | + ... for j in range(2): |
| 482 | + ... for k in range(3): |
| 483 | + ... for n in range(4): |
| 484 | + ... d[i,j] += a[k,n,i] * b[n,k,j] |
| 485 | + >>> c == d |
| 486 | + array([[ True, True], |
| 487 | + [ True, True], |
| 488 | + [ True, True], |
| 489 | + [ True, True], |
| 490 | + [ True, True]]) |
434 | 491 |
|
435 | 492 | """
|
436 | 493 |
|
437 |
| - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) |
438 |
| - x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False) |
439 |
| - if x1_desc and x2_desc and (axes == 1): |
440 |
| - return dpnp_tensordot_not_implemented(x1_desc, x2_desc) # dpnp_matmul |
| 494 | + dpnp.check_supported_arrays_type(a, b, scalar_type=True) |
441 | 495 |
|
442 |
| - return call_origin(numpy.tensordot, x1, x2, axes) |
| 496 | + if dpnp.isscalar(a): |
| 497 | + a = dpnp.array(a, sycl_queue=b.sycl_queue, usm_type=b.usm_type) |
| 498 | + elif dpnp.isscalar(b): |
| 499 | + b = dpnp.array(b, sycl_queue=a.sycl_queue, usm_type=a.usm_type) |
| 500 | + |
| 501 | + try: |
| 502 | + iter(axes) |
| 503 | + except Exception: |
| 504 | + if not isinstance(axes, int): |
| 505 | + raise ValueError("Axes must be an integer.") |
| 506 | + axes_a = tuple(range(-axes, 0)) |
| 507 | + axes_b = tuple(range(0, axes)) |
| 508 | + else: |
| 509 | + if len(axes) != 2: |
| 510 | + raise ValueError("Axes must consist of two sequences.") |
| 511 | + |
| 512 | + axes_a, axes_b = axes |
| 513 | + axes_a = (axes_a,) if dpnp.isscalar(axes_a) else axes_a |
| 514 | + axes_b = (axes_b,) if dpnp.isscalar(axes_b) else axes_b |
| 515 | + |
| 516 | + if len(axes_a) != len(axes_b): |
| 517 | + raise ValueError("Axes length mismatch.") |
| 518 | + |
| 519 | + a_shape = a.shape |
| 520 | + b_shape = b.shape |
| 521 | + for axis_a, axis_b in zip(axes_a, axes_b): |
| 522 | + if a_shape[axis_a] != b_shape[axis_b]: |
| 523 | + raise ValueError( |
| 524 | + "shape of input arrays is not similar at requested axes." |
| 525 | + ) |
| 526 | + |
| 527 | + # Make the axes non-negative |
| 528 | + a_ndim = a.ndim |
| 529 | + b_ndim = b.ndim |
| 530 | + axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis") |
| 531 | + axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis") |
| 532 | + |
| 533 | + # Move the axes to sum over, to the end of "a" |
| 534 | + notin = tuple(k for k in range(a_ndim) if k not in axes_a) |
| 535 | + newaxes_a = notin + axes_a |
| 536 | + N1 = int(numpy.prod([a_shape[ax] for ax in notin])) |
| 537 | + N2 = int(numpy.prod([a_shape[ax] for ax in axes_a])) |
| 538 | + newshape_a = (N1, N2) |
| 539 | + olda = [a_shape[axis] for axis in notin] |
| 540 | + |
| 541 | + # Move the axes to sum over, to the front of "b" |
| 542 | + notin = tuple(k for k in range(b_ndim) if k not in axes_b) |
| 543 | + newaxes_b = tuple(axes_b + notin) |
| 544 | + N1 = int(numpy.prod([b_shape[ax] for ax in axes_b])) |
| 545 | + N2 = int(numpy.prod([b_shape[ax] for ax in notin])) |
| 546 | + newshape_b = (N1, N2) |
| 547 | + oldb = [b_shape[axis] for axis in notin] |
| 548 | + |
| 549 | + at = a.transpose(newaxes_a).reshape(newshape_a) |
| 550 | + bt = b.transpose(newaxes_b).reshape(newshape_b) |
| 551 | + res = dpnp.matmul(at, bt) |
| 552 | + |
| 553 | + return res.reshape(olda + oldb) |
443 | 554 |
|
444 | 555 |
|
445 | 556 | def vdot(a, b):
|
|
0 commit comments