Skip to content

Commit 2224ce2

Browse files
authored
Add parameter out in dpnp.dot() (#1327)
1 parent 648612d commit 2224ce2

File tree

9 files changed

+113
-72
lines changed

9 files changed

+113
-72
lines changed

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pyx

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# cython: language_level=3
22
# -*- coding: utf-8 -*-
33
# *****************************************************************************
4-
# Copyright (c) 2016-2020, Intel Corporation
4+
# Copyright (c) 2016-2023, Intel Corporation
55
# All rights reserved.
66
#
77
# Redistribution and use in source and binary forms, with or without
@@ -65,8 +65,9 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_2in_1out_matmul_t)(c_dpctl.DPCTLSyclQue
6565
const shape_elem_type *, const shape_elem_type * ,
6666
const c_dpctl.DPCTLEventVectorRef)
6767

68-
cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp_descriptor in_array2):
69-
68+
cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1,
69+
utils.dpnp_descriptor in_array2,
70+
utils.dpnp_descriptor out=None):
7071
cdef shape_type_c shape1, shape2
7172

7273
shape1 = in_array1.shape
@@ -78,6 +79,7 @@ cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp
7879

7980
# get the FPTR data structure
8081
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_DOT_EXT, param1_type, param2_type)
82+
cdef utils.dpnp_descriptor result
8183

8284
ndim1 = in_array1.ndim
8385
ndim2 = in_array2.ndim
@@ -89,7 +91,7 @@ cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp
8991
elif ndim1 == 1 and ndim2 == 1:
9092
result_shape = ()
9193
elif ndim1 == 1: # ndim2 > 1
92-
result_shape = shape2[:-1]
94+
result_shape = shape2[::-2] if ndim2 == 2 else shape2[::2]
9395
elif ndim2 == 1: # ndim1 > 1
9496
result_shape = shape1[:-1]
9597
else:
@@ -101,13 +103,24 @@ cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp
101103

102104
result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(in_array1, in_array2)
103105

104-
# create result array with type given by FPTR data
105-
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape,
106-
kernel_data.return_type,
107-
None,
108-
device=result_sycl_device,
109-
usm_type=result_usm_type,
110-
sycl_queue=result_sycl_queue)
106+
if out is None:
107+
# create result array with type given by FPTR data
108+
result = utils.create_output_descriptor(result_shape,
109+
kernel_data.return_type,
110+
None,
111+
device=result_sycl_device,
112+
usm_type=result_usm_type,
113+
sycl_queue=result_sycl_queue)
114+
else:
115+
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
116+
if out.dtype != result_type:
117+
utils.checker_throw_value_error('dot', 'out.dtype', out.dtype, result_type)
118+
if out.shape != result_shape:
119+
utils.checker_throw_value_error('dot', 'out.shape', out.shape, result_shape)
120+
121+
result = out
122+
123+
utils.get_common_usm_allocation(in_array1, result) # check USM allocation is common
111124

112125
cdef shape_type_c result_strides = utils.strides_to_vector(result.strides, result.shape)
113126
cdef shape_type_c in_array1_shape = in_array1.shape

dpnp/dpnp_array.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,8 @@ def diagonal(input, offset=0, axis1=0, axis2=1):
592592

593593
return dpnp.diagonal(input, offset, axis1, axis2)
594594

595-
# 'dot',
595+
def dot(self, other, out=None):
596+
return dpnp.dot(self, other, out)
596597

597598
@property
598599
def dtype(self):

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@
4444
from dpnp.dpnp_algo import *
4545
from dpnp.dpnp_utils import *
4646
import dpnp
47-
import dpnp.config as config
4847

4948
import numpy
49+
import dpctl.tensor as dpt
5050

5151

5252
__all__ = [
@@ -62,18 +62,25 @@
6262
]
6363

6464

65-
def dot(x1, x2, **kwargs):
65+
def dot(x1, x2, out=None, **kwargs):
6666
"""
67-
Returns the dot product of `x1` and `x2`.
67+
Dot product of `x1` and `x2`.
6868
6969
For full documentation refer to :obj:`numpy.dot`.
7070
71+
Returns
72+
-------
73+
y : dpnp.ndarray
74+
Returns the dot product of `x1` and `x2`.
75+
If `out` is given, then it is returned.
76+
7177
Limitations
7278
-----------
73-
Parameters ``x1`` and ``x2`` are supported as :obj:`dpnp.ndarray` of the same type.
74-
Keyword arguments ``kwargs`` are currently unsupported.
75-
Otherwise the functions will be executed sequentially on CPU.
76-
Input array data types are limited by supported DPNP :ref:`Data types`.
79+
Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
80+
or :class:`dpctl.tensor.usm_ndarray`, but `x1` and `x2` cannot both be scalars at the same time.
81+
Keyword argument ``kwargs`` is currently unsupported.
82+
Otherwise the function will be executed sequentially on CPU.
83+
Input array data types are limited by supported DPNP :ref:`Data types`.
7784
7885
See Also
7986
--------
@@ -82,31 +89,37 @@ def dot(x1, x2, **kwargs):
8289
8390
Examples
8491
--------
85-
>>> import dpnp as np
86-
>>> np.dot(3, 4)
87-
12
88-
>>> a = np.array([1, 2, 3])
89-
>>> b = np.array([1, 2, 3])
90-
>>> np.dot(a, b)
92+
>>> import dpnp as dp
93+
>>> a = dp.array([1, 2, 3])
94+
>>> b = dp.array([1, 2, 3])
95+
>>> dp.dot(a, b)
9196
14
9297
9398
"""
9499

95-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False)
96-
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False)
97-
if x1_desc and x2_desc and not kwargs:
98-
# TODO: remove fallback with scalars when muliply backend func will support strides
99-
if(x1_desc.ndim == 0 and x2_desc.strides is not None
100-
or x2_desc.ndim == 0 and x1_desc.strides is not None):
101-
pass
102-
elif (x1_desc.ndim >= 1 and x2_desc.ndim > 1 and x1_desc.shape[-1] != x2_desc.shape[-2]):
103-
pass
104-
elif (x1_desc.ndim > 0 and x2_desc.ndim == 1 and x1_desc.shape[-1] != x2_desc.shape[0]):
105-
pass
106-
else:
107-
return dpnp_dot(x1_desc, x2_desc).get_pyobj()
100+
if kwargs:
101+
pass
102+
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
103+
# at least one of x1 and x2 has to be an array
104+
pass
105+
else:
106+
# get USM type and queue to copy scalar from the host memory into a USM allocation
107+
usm_type, queue = get_usm_allocations([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else (None, None)
108+
109+
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False,
110+
alloc_usm_type=usm_type, alloc_queue=queue)
111+
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False,
112+
alloc_usm_type=usm_type, alloc_queue=queue)
113+
if x1_desc and x2_desc:
114+
if out is not None:
115+
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
116+
raise TypeError("return array must be of supported array type")
117+
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False)
118+
else:
119+
out_desc = None
120+
return dpnp_dot(x1_desc, x2_desc, out=out_desc).get_pyobj()
108121

109-
return call_origin(numpy.dot, x1, x2, **kwargs)
122+
return call_origin(numpy.dot, x1, x2, out=out, **kwargs)
110123

111124

112125
def einsum(*args, **kwargs):

tests/skipped_tests.tbl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -610,10 +610,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumLarge_param_9_{opt
610610
tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_float
611611
tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_int
612612
tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_invalid_sub1
613-
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_64_{shape=((2,), (2, 4)), trans_a=True, trans_b=True}::test_dot
614-
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_65_{shape=((2,), (2, 4)), trans_a=True, trans_b=False}::test_dot
615-
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_66_{shape=((2,), (2, 4)), trans_a=False, trans_b=True}::test_dot
616-
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_67_{shape=((2,), (2, 4)), trans_a=False, trans_b=False}::test_dot
617613
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge
618614
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large
619615
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -812,10 +812,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWith
812812
tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_int
813813
tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_invalid_sub1
814814
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot
815-
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_64_{shape=((2,), (2, 4)), trans_a=True, trans_b=True}::test_dot
816-
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_65_{shape=((2,), (2, 4)), trans_a=True, trans_b=False}::test_dot
817-
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_66_{shape=((2,), (2, 4)), trans_a=False, trans_b=True}::test_dot
818-
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_67_{shape=((2,), (2, 4)), trans_a=False, trans_b=False}::test_dot
819815
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge
820816
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large
821817
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two
@@ -827,7 +823,6 @@ tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transpose
827823
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot
828824
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_int_axes
829825
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_list_axes
830-
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::
831826
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_tensordot_zero_dim
832827
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot_with_out_f_contiguous
833828
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_multidim_vdot

tests/test_dot.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import pytest
2+
from .helper import get_all_dtypes
23

34
import dpnp as inp
45

56
import numpy
7+
from numpy.testing import (
8+
assert_allclose,
9+
assert_array_equal
10+
)
611

712

8-
@pytest.mark.parametrize("type",
9-
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
10-
ids=['float64', 'float32', 'int64', 'int32'])
13+
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
1114
def test_dot_ones(type):
1215
n = 10**5
1316
a = numpy.ones(n, dtype=type)
@@ -17,12 +20,10 @@ def test_dot_ones(type):
1720

1821
result = inp.dot(ia, ib)
1922
expected = numpy.dot(a, b)
20-
numpy.testing.assert_array_equal(expected, result)
23+
assert_array_equal(expected, result)
2124

2225

23-
@pytest.mark.parametrize("type",
24-
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
25-
ids=['float64', 'float32', 'int64', 'int32'])
26+
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
2627
def test_dot_arange(type):
2728
n = 10**2
2829
m = 10**3
@@ -33,12 +34,10 @@ def test_dot_arange(type):
3334

3435
result = inp.dot(ia, ib)
3536
expected = numpy.dot(a, b)
36-
numpy.testing.assert_allclose(expected, result)
37+
assert_allclose(expected, result)
3738

3839

39-
@pytest.mark.parametrize("type",
40-
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
41-
ids=['float64', 'float32', 'int64', 'int32'])
40+
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
4241
def test_multi_dot(type):
4342
n = 16
4443
a = inp.reshape(inp.arange(n, dtype=type), (4, 4))
@@ -53,4 +52,4 @@ def test_multi_dot(type):
5352

5453
result = inp.linalg.multi_dot([a, b, c, d])
5554
expected = numpy.linalg.multi_dot([a1, b1, c1, d1])
56-
numpy.testing.assert_array_equal(expected, result)
55+
assert_array_equal(expected, result)

tests/test_sycl_queue.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def test_1in_1out(func, data, device):
297297
x = dpnp.array(data, device=device)
298298
result = getattr(dpnp, func)(x)
299299

300-
numpy.testing.assert_array_equal(result, expected)
300+
assert_array_equal(result, expected)
301301

302302
expected_queue = x.get_array().sycl_queue
303303
result_queue = result.get_array().sycl_queue
@@ -320,6 +320,9 @@ def test_1in_1out(func, data, device):
320320
pytest.param("divide",
321321
[0., 1., 2., 3., 4.],
322322
[4., 4., 4., 4., 4.]),
323+
pytest.param("dot",
324+
[[0., 1., 2.], [3., 4., 5.]],
325+
[[4., 4.], [4., 4.], [4., 4.]]),
323326
pytest.param("floor_divide",
324327
[1., 2., 3., 4.],
325328
[2.5, 2.5, 2.5, 2.5]),
@@ -364,7 +367,7 @@ def test_2in_1out(func, data1, data2, device):
364367
x2 = dpnp.array(data2, device=device)
365368
result = getattr(dpnp, func)(x1, x2)
366369

367-
numpy.testing.assert_array_equal(result, expected)
370+
assert_array_equal(result, expected)
368371

369372
assert_sycl_queue_equal(result.sycl_queue, x1.sycl_queue)
370373
assert_sycl_queue_equal(result.sycl_queue, x2.sycl_queue)
@@ -539,6 +542,9 @@ def test_random_state(func, args, kwargs, device, usm_type):
539542
pytest.param("divide",
540543
[0., 1., 2., 3., 4.],
541544
[4., 4., 4., 4., 4.]),
545+
pytest.param("dot",
546+
[[0., 1., 2.], [3., 4., 5.]],
547+
[[4., 4.], [4., 4.], [4., 4.]]),
542548
pytest.param("floor_divide",
543549
[1., 2., 3., 4.],
544550
[2.5, 2.5, 2.5, 2.5]),
@@ -571,20 +577,20 @@ def test_random_state(func, args, kwargs, device, usm_type):
571577
def test_out(func, data1, data2, device):
572578
x1_orig = numpy.array(data1)
573579
x2_orig = numpy.array(data2)
574-
expected = numpy.empty(x1_orig.size)
575-
numpy.add(x1_orig, x2_orig, out=expected)
580+
np_out = getattr(numpy, func)(x1_orig, x2_orig)
581+
expected = numpy.empty_like(np_out)
582+
getattr(numpy, func)(x1_orig, x2_orig, out=expected)
576583

577584
x1 = dpnp.array(data1, device=device)
578585
x2 = dpnp.array(data2, device=device)
579-
result = dpnp.empty(x1.size, device=device)
580-
dpnp.add(x1, x2, out=result)
586+
dp_out = getattr(dpnp, func)(x1, x2)
587+
result = dpnp.empty_like(dp_out)
588+
getattr(dpnp, func)(x1, x2, out=result)
581589

582-
numpy.testing.assert_array_equal(result, expected)
590+
assert_array_equal(result, expected)
583591

584-
expected_queue = x1.get_array().sycl_queue
585-
result_queue = result.get_array().sycl_queue
586-
587-
assert_sycl_queue_equal(result_queue, expected_queue)
592+
assert_sycl_queue_equal(result.sycl_queue, x1.sycl_queue)
593+
assert_sycl_queue_equal(result.sycl_queue, x2.sycl_queue)
588594

589595

590596
@pytest.mark.parametrize("device",

tests/test_usm_type.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,22 @@ def test_meshgrid(usm_type_x, usm_type_y):
154154
z = dp.meshgrid(x, y)
155155
assert z[0].usm_type == usm_type_x
156156
assert z[1].usm_type == usm_type_y
157+
158+
@pytest.mark.parametrize(
159+
"func,data1,data2",
160+
[
161+
pytest.param("dot",
162+
[[0., 1., 2.], [3., 4., 5.]],
163+
[[4., 4.], [4., 4.], [4., 4.]]),
164+
],
165+
)
166+
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
167+
@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types)
168+
def test_2in_1out(func, data1, data2, usm_type_x, usm_type_y):
169+
x = dp.array(data1, usm_type = usm_type_x)
170+
y = dp.array(data2, usm_type = usm_type_y)
171+
z = getattr(dp, func)(x, y)
172+
173+
assert x.usm_type == usm_type_x
174+
assert y.usm_type == usm_type_y
175+
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])

tests/third_party/cupy/linalg_tests/test_product.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
'trans_a': [True, False],
3232
'trans_b': [True, False],
3333
}))
34-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
3534
@testing.gpu
3635
class TestDot(unittest.TestCase):
3736

0 commit comments

Comments
 (0)