Skip to content

Commit 815ec86

Browse files
committed
address comments
1 parent d3ba4f3 commit 815ec86

File tree

4 files changed

+142
-102
lines changed

4 files changed

+142
-102
lines changed

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,19 +82,14 @@ def dot(a, b, out=None):
8282
Returns the dot product of `a` and `b`.
8383
If `out` is given, then it is returned.
8484
85-
Limitations
86-
-----------
87-
Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
88-
or :class:`dpctl.tensor.usm_ndarray`, but both `x1` and `x2` can not be scalars at the same time.
89-
Keyword argument ``kwargs`` is currently unsupported.
90-
Otherwise the functions will be executed sequentially on CPU.
91-
Input array data types are limited by supported DPNP :ref:`Data types`.
92-
9385
See Also
9486
--------
9587
:obj:`dpnp.ndarray.dot` : Equivalent method.
9688
:obj:`dpnp.tensordot` : Sum products over arbitrary axes.
9789
:obj:`dpnp.vdot` : Complex-conjugating dot product.
90+
:obj:`dpnp.einsum` : Einstein summation convention.
91+
:obj:`dpnp.matmul` : Matrix product of two arrays.
92+
:obj:`dpnp.linalg.multi_dot` : Chained dot product.
9893
9994
Examples
10095
--------
@@ -135,15 +130,19 @@ def dot(a, b, out=None):
135130
raise ValueError("Only C-contiguous array is acceptable.")
136131

137132
if dpnp.isscalar(a) or dpnp.isscalar(b):
133+
# TODO: investigate usage of axpy (axpy_batch) or scal
134+
# functions from BLAS here instead of dpnp.multiply
138135
return dpnp.multiply(a, b, out=out)
139136
elif a.ndim == 0 or b.ndim == 0:
137+
# TODO: investigate usage of axpy (axpy_batch) or scal
138+
# functions from BLAS here instead of dpnp.multiply
140139
return dpnp.multiply(a, b, out=out)
141140
elif a.ndim == 1 and b.ndim == 1:
142141
return dpnp_dot(a, b, out=out)
143142
elif a.ndim == 2 and b.ndim == 2:
144143
# NumPy does not allow casting even if it is safe
145144
return dpnp.matmul(a, b, out=out, casting="no")
146-
elif a.ndim > 1 and b.ndim == 1:
145+
elif a.ndim == 1 or b.ndim == 1:
147146
# NumPy does not allow casting even if it is safe
148147
return dpnp.matmul(a, b, out=out, casting="no")
149148
else:

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 76 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -34,61 +34,34 @@
3434
__all__ = ["dpnp_dot", "dpnp_matmul"]
3535

3636

37-
def _op_res_dtype(*arrays, dtype, casting, sycl_queue):
37+
def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None):
3838
"""
39-
_op_res_dtype(*arrays, dtype, casting, sycl_queue)
40-
41-
Determines the output array data type and an intermediate data type
42-
used in performing calculations related to a specific math function.
43-
If dtype is ``None``, the output array data type of the operation is
44-
determined based on the Promotion Type Rule and device capabilities.
45-
Otherwise, `dtype` is used as output array dtype, if input arrays
46-
can cast to it according to the casting rule determined. If casting
47-
cannot be done, a ``TypeError`` is raised.
48-
The intermediate data type is the data type used for performing the math
49-
function calculations. If output array dtype is a floating-point data type,
50-
it is also used for the intermediate data type. If output array dtype is an
51-
integral data type, the default floating point data type of the device where
52-
input arrays are allocated is used for intermediate data type.
53-
54-
Parameters
55-
----------
56-
arrays : {dpnp.ndarray, usm_ndarray}
57-
Input arrays.
58-
dtype : dtype
59-
If not ``None``, data type of the output array.
60-
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
61-
Controls what kind of data casting may occur.
62-
sycl_queue : {SyclQueue}
63-
A SYCL queue to use for determining default floating point data type.
39+
Creating a copy of input array if needed.
6440
65-
Returns
66-
-------
67-
op_dtype, res_dtype :
68-
`op_dtype` is the data type used in performing math function calculations.
69-
The input arrays of the math function are cast to `op_dtype` and then
70-
the calculations are performed.
71-
`res_dtype` is the output data type. When the result is obtained, it is cast
72-
to `res_dtype`.
41+
If `contig_copy` is ``True``, a C-contiguous copy of input array is returned.
42+
In this case, the copy array has the input array data type unless `dtype` is
43+
specified.
44+
If `contig_copy` is ``False`` and input array data type is different than `dtype`,
45+
a C-contiguous copy of input array with specified `dtype` is returned.
7346
7447
"""
7548

76-
res_dtype = dpnp.result_type(*arrays)
77-
default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue)
78-
79-
if dtype is not None:
80-
if dpnp.can_cast(res_dtype, dtype, casting=casting):
81-
res_dtype = dtype
82-
else:
83-
raise TypeError(
84-
f"Cannot cast ufunc 'matmul' output from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}"
85-
)
86-
87-
op_dtype = (
88-
res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype
89-
)
49+
if contig_copy:
50+
copy = contig_copy
51+
else:
52+
copy = x.dtype != dtype if dtype is not None else False
9053

91-
return op_dtype, res_dtype
54+
if copy:
55+
x_copy = dpnp.empty_like(x, dtype=dtype, order="C")
56+
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
57+
src=dpnp.get_usm_ndarray(x),
58+
dst=x_copy.get_array(),
59+
sycl_queue=x.sycl_queue,
60+
)
61+
dep_events.append(copy_ev)
62+
host_events.append(ht_copy_ev)
63+
return x_copy
64+
return x
9265

9366

9467
def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list):
@@ -153,34 +126,61 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list):
153126
return ht_blas_ev, ht_tasks_list, res
154127

155128

156-
def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None):
129+
def _op_res_dtype(*arrays, dtype, casting, sycl_queue):
157130
"""
158-
Creating a copy of input array if needed.
131+
_op_res_dtype(*arrays, dtype, casting, sycl_queue)
159132
160-
If `contig_copy` is ``True``, a C-contiguous copy of input array is returned.
161-
In this case, the copy array has the input array data type unless `dtype` is
162-
specified.
163-
If `contig_copy` is ``False`` and input array data type is different than `dtype`,
164-
a C-contiguous copy of input array with specified `dtype` is returned.
133+
Determines the output array data type and an intermediate data type
134+
used in performing calculations related to a specific math function.
135+
If dtype is ``None``, the output array data type of the operation is
136+
determined based on the Promotion Type Rule and device capabilities.
137+
Otherwise, `dtype` is used as output array dtype, if input arrays
138+
can cast to it according to the casting rule determined. If casting
139+
cannot be done, a ``TypeError`` is raised.
140+
The intermediate data type is the data type used for performing the math
141+
function calculations. If output array dtype is a floating-point data type,
142+
it is also used for the intermediate data type. If output array dtype is an
143+
integral data type, the default floating point data type of the device where
144+
input arrays are allocated is used for intermediate data type.
145+
146+
Parameters
147+
----------
148+
arrays : {dpnp.ndarray, usm_ndarray}
149+
Input arrays.
150+
dtype : dtype
151+
If not ``None``, data type of the output array.
152+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
153+
Controls what kind of data casting may occur.
154+
sycl_queue : {SyclQueue}
155+
A SYCL queue to use for determining default floating point data type.
156+
157+
Returns
158+
-------
159+
op_dtype, res_dtype :
160+
`op_dtype` is the data type used in performing math function calculations.
161+
The input arrays of the math function are cast to `op_dtype` and then
162+
the calculations are performed.
163+
`res_dtype` is the output data type. When the result is obtained, it is cast
164+
to `res_dtype`.
165165
166166
"""
167167

168-
if contig_copy:
169-
copy = contig_copy
170-
else:
171-
copy = x.dtype != dtype if dtype is not None else False
168+
res_dtype = dpnp.result_type(*arrays)
169+
default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue)
172170

173-
if copy:
174-
x_copy = dpnp.empty_like(x, dtype=dtype, order="C")
175-
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
176-
src=dpnp.get_usm_ndarray(x),
177-
dst=x_copy.get_array(),
178-
sycl_queue=x.sycl_queue,
179-
)
180-
dep_events.append(copy_ev)
181-
host_events.append(ht_copy_ev)
182-
return x_copy
183-
return x
171+
if dtype is not None:
172+
if dpnp.can_cast(res_dtype, dtype, casting=casting):
173+
res_dtype = dtype
174+
else:
175+
raise TypeError(
176+
f"Cannot cast ufunc 'matmul' output from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}"
177+
)
178+
179+
op_dtype = (
180+
res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype
181+
)
182+
183+
return op_dtype, res_dtype
184184

185185

186186
def dpnp_dot(
@@ -394,6 +394,11 @@ def dpnp_matmul(
394394
dtype=gemm_dtype,
395395
)
396396

397+
# TODO: investigate usage of gemv (gemv_batch) function
398+
# from BLAS when one of the inputs is a vector to
399+
# gain performance.
400+
# TODO: investigate usage of syrk function from BLAS in
401+
# case of a.T @ a and a @ a.T to gain performance.
397402
if x1_is_2D and x2_is_2D:
398403
ht_blas_ev, _ = bi._gemm(
399404
exec_q,

tests/third_party/cupy/linalg_tests/test_eigenvalue.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,6 @@ def _get_hermitian(xp, a, UPLO):
1515
return xp.tril(a) + xp.tril(a, k=-1).swapaxes(-2, -1).conj()
1616

1717

18-
# TODO:
19-
# remove once dpnp.dot and dpnp.matmul support complex types
20-
def _wrap_as_numpy_array(xp, a):
21-
return a.asnumpy() if xp is cupy else a
22-
23-
2418
@testing.parameterize(
2519
*testing.product(
2620
{
@@ -57,20 +51,12 @@ def test_eigh(self, xp, dtype):
5751
else:
5852
tol = 1e-5
5953

60-
# TODO: remove _wrap_as_numpy_array() once @ supports complex types
61-
testing.assert_allclose(
62-
_wrap_as_numpy_array(xp, A) @ _wrap_as_numpy_array(xp, v),
63-
_wrap_as_numpy_array(xp, v)
64-
@ numpy.diag(_wrap_as_numpy_array(xp, w)),
65-
atol=tol,
66-
rtol=tol,
67-
)
54+
testing.assert_allclose(A @ v, v @ xp.diag(w), atol=tol, rtol=tol)
6855

6956
# Check if v @ vt is an identity matrix
7057
testing.assert_allclose(
71-
_wrap_as_numpy_array(xp, v)
72-
@ _wrap_as_numpy_array(xp, v).swapaxes(-2, -1).conj(),
73-
numpy.identity(_wrap_as_numpy_array(xp, A).shape[-1], _dtype),
58+
v @ v.swapaxes(-2, -1).conj(),
59+
xp.identity(A.shape[-1], _dtype),
7460
atol=tol,
7561
rtol=tol,
7662
)
@@ -121,11 +107,6 @@ def test_eigh_complex_batched(self, xp, dtype):
121107
# them through the eigen equation A*v=w*v.
122108
A = _get_hermitian(xp, a, self.UPLO)
123109

124-
# TODO: remove _wrap_as_numpy_array() once dpnp.dot() supports complex types
125-
A = _wrap_as_numpy_array(xp, A)
126-
v = _wrap_as_numpy_array(xp, v)
127-
w = _wrap_as_numpy_array(xp, w)
128-
129110
for i in range(a.shape[0]):
130111
testing.assert_allclose(
131112
A[i].dot(v[i]), w[i] * v[i], rtol=1e-5, atol=1e-5

tests/third_party/cupy/math_tests/test_matmul.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,61 @@ def test_cupy_matmul(self, xp, dtype1):
7373
return xp.matmul(x1, x2)
7474

7575

76+
@testing.parameterize(
77+
*testing.product(
78+
{
79+
"shape_pair": [
80+
# dot test
81+
((2, 3), (3, 4), (2, 4)),
82+
# ((0,), (0,), (0,)),
83+
# matmul test
84+
((5, 3, 2), (5, 2, 4), (5, 3, 4)),
85+
((0, 3, 2), (0, 2, 4), (0, 3, 4)),
86+
],
87+
}
88+
)
89+
)
90+
class TestMatmulOut(unittest.TestCase):
91+
@testing.for_all_dtypes(name="dtype1")
92+
@testing.for_all_dtypes(name="dtype2")
93+
@testing.numpy_cupy_allclose(
94+
rtol=1e-3, atol=1e-3, accept_error=TypeError # required for uint8
95+
)
96+
def test_cupy_matmul_noncontiguous(self, xp, dtype1, dtype2):
97+
x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1)
98+
x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype2)
99+
out = xp.zeros(self.shape_pair[2], dtype=dtype1)[::-1]
100+
ret = xp.matmul(x1, x2, out=out)
101+
assert ret is out
102+
return ret
103+
104+
@testing.for_all_dtypes(name="dtype1")
105+
@testing.for_all_dtypes(name="dtype2")
106+
@testing.numpy_cupy_allclose(rtol=1e-3, atol=1e-3) # required for uint8
107+
def test_cupy_matmul_out_cast(self, xp, dtype1, dtype2):
108+
x1 = testing.shaped_arange(self.shape_pair[0], xp, dtype1)
109+
x2 = testing.shaped_arange(self.shape_pair[1], xp, dtype2)
110+
out = xp.zeros(self.shape_pair[2], dtype=bool)
111+
ret = xp.matmul(x1, x2, out=out, casting="unsafe")
112+
assert ret is out
113+
return ret
114+
115+
116+
class TestMatmulOutOverlap:
117+
@pytest.mark.parametrize(
118+
"shape",
119+
[
120+
(900, 900),
121+
(2, 600, 600),
122+
],
123+
)
124+
@testing.for_dtypes([numpy.int32, numpy.float64])
125+
@testing.numpy_cupy_allclose(rtol=1e-5, atol=1e-5)
126+
def test_overlap_both(self, xp, dtype, shape):
127+
a = xp.ones(shape, dtype=dtype)
128+
return xp.matmul(a, a, out=a)
129+
130+
76131
@testing.parameterize(
77132
*testing.product(
78133
{

0 commit comments

Comments
 (0)