Skip to content

Commit 63aab70

Browse files
committed
address comments
1 parent 8e4f733 commit 63aab70

File tree

3 files changed

+45
-97
lines changed

3 files changed

+45
-97
lines changed

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,7 @@
4343
import dpnp
4444
from dpnp.dpnp_algo import *
4545
from dpnp.dpnp_utils import *
46-
from dpnp.dpnp_utils.dpnp_utils_linearalgebra import (
47-
dpnp_dot,
48-
dpnp_matmul,
49-
dpnp_vdot,
50-
)
46+
from dpnp.dpnp_utils.dpnp_utils_linearalgebra import dpnp_dot, dpnp_matmul
5147

5248
__all__ = [
5349
"dot",
@@ -456,11 +452,13 @@ def vdot(a, b):
456452
457453
Parameters
458454
----------
459-
a : {dpnp_array, usm_ndarray}
460-
First input array. If `a` is complex the complex conjugate
461-
is taken before the calculation of the dot product.
455+
a : {dpnp_array, usm_ndarray, scalar}
456+
First input array. Both inputs `a` and `b` can not be
457+
scalars at the same time. If `a` is complex, the complex
458+
conjugate is taken before the calculation of the dot product.
462459
b : {dpnp_array, usm_ndarray, scalar}
463-
Second input array.
460+
Second input array. Both inputs `a` and `b` can not be
461+
scalars at the same time.
464462
465463
Returns
466464
-------
@@ -494,17 +492,19 @@ def vdot(a, b):
494492
495493
"""
496494

497-
dpnp.check_supported_arrays_type(a)
495+
dpnp.check_supported_arrays_type(a, scalar_type=True)
498496
dpnp.check_supported_arrays_type(b, scalar_type=True)
499497

500-
if dpnp.isscalar(b):
501-
if a.size != 1:
498+
if dpnp.isscalar(a) or dpnp.isscalar(b):
499+
if dpnp.isscalar(b) and a.size != 1:
502500
raise ValueError("The first array should be of size one.")
501+
if dpnp.isscalar(a) and b.size != 1:
502+
raise ValueError("The second array should be of size one.")
503503
# TODO: investigate usage of axpy (axpy_batch) or scal
504504
# functions from BLAS here instead of dpnp.multiply
505-
return dpnp.multiply(dpnp.conj(a), b)
505+
return dpnp.multiply(numpy.conj(a), b)
506506
elif a.ndim == 1 and b.ndim == 1:
507-
return dpnp_vdot(a, b)
507+
return dpnp_dot(a, b, out=None, conjugate=True)
508508
else:
509509
# dot product of flatten arrays
510-
return dpnp_vdot(dpnp.ravel(a), dpnp.ravel(b))
510+
return dpnp_dot(dpnp.ravel(a), dpnp.ravel(b), out=None, conjugate=True)

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 22 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from dpnp.dpnp_array import dpnp_array
3434
from dpnp.dpnp_utils import get_usm_allocations
3535

36-
__all__ = ["dpnp_dot", "dpnp_matmul", "dpnp_vdot"]
36+
__all__ = ["dpnp_dot", "dpnp_matmul"]
3737

3838

3939
def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None):
@@ -185,7 +185,7 @@ def _op_res_dtype(*arrays, dtype, casting, sycl_queue):
185185
return op_dtype, res_dtype
186186

187187

188-
def dpnp_dot(a, b, /, out=None):
188+
def dpnp_dot(a, b, /, out=None, *, conjugate=False):
189189
"""
190190
Return the dot product of two arrays.
191191
@@ -194,7 +194,9 @@ def dpnp_dot(a, b, /, out=None):
194194
`dpctl.tensor.vecdot` form the Data Parallel Control library is used,
195195
2) For real-valued floating point data types, `dot` routines from
196196
BLAS library of OneMKL are used, and 3) For complex data types,
197-
`dotu` routines from BLAS library of OneMKL are used.
197+
`dotu` or `dotc` routines from BLAS library of OneMKL are used.
198+
If `conjugate` is ``False``, `dotu` is used. Otherwise, `dotc` is used,
199+
for which the first array is conjugated before calculating the dot product.
198200
199201
"""
200202

@@ -228,13 +230,22 @@ def dpnp_dot(a, b, /, out=None):
228230
a = _copy_array(a, dep_events_list, host_tasks_list, dtype=dot_dtype)
229231
b = _copy_array(b, dep_events_list, host_tasks_list, dtype=dot_dtype)
230232
if dpnp.issubdtype(res_dtype, dpnp.complexfloating):
231-
ht_ev, _ = bi._dotu(
232-
exec_q,
233-
dpnp.get_usm_ndarray(a),
234-
dpnp.get_usm_ndarray(b),
235-
dpnp.get_usm_ndarray(result),
236-
dep_events_list,
237-
)
233+
if conjugate:
234+
ht_ev, _ = bi._dotc(
235+
exec_q,
236+
dpnp.get_usm_ndarray(a),
237+
dpnp.get_usm_ndarray(b),
238+
dpnp.get_usm_ndarray(result),
239+
dep_events_list,
240+
)
241+
else:
242+
ht_ev, _ = bi._dotu(
243+
exec_q,
244+
dpnp.get_usm_ndarray(a),
245+
dpnp.get_usm_ndarray(b),
246+
dpnp.get_usm_ndarray(result),
247+
dep_events_list,
248+
)
238249
else:
239250
ht_ev, _ = bi._dot(
240251
exec_q,
@@ -253,7 +264,7 @@ def dpnp_dot(a, b, /, out=None):
253264
if dot_dtype != res_dtype:
254265
result = result.astype(res_dtype, copy=False)
255266

256-
# NumPy does not allow casting even if it is safe
267+
# numpy.dot does not allow casting even if it is safe
257268
return dpnp.get_result_array(result, out, casting="no")
258269

259270

@@ -447,74 +458,3 @@ def dpnp_matmul(
447458
return result
448459
else:
449460
return dpnp.get_result_array(result, out, casting=casting)
450-
451-
452-
def dpnp_vdot(a, b):
453-
"""
454-
Return the dot product of two arrays.
455-
456-
The routine that is used to perform the main calculation
457-
depends on input arrays data type: 1) For integer and boolean data types,
458-
`dpctl.tensor.vecdot` form the Data Parallel Control library is used,
459-
2) For real-valued floating point data types, `dot` routines from
460-
BLAS library of OneMKL are used, and 3) For complex data types,
461-
`dotc` routines from BLAS library of OneMKL are used.
462-
463-
"""
464-
465-
if a.size != b.size:
466-
raise ValueError(
467-
"Input arrays have a mismatch in their size. "
468-
f"(size {a.size} is different from {b.size})"
469-
)
470-
471-
res_usm_type, exec_q = get_usm_allocations([a, b])
472-
473-
# Determine the appropriate data types
474-
# casting is irrelevant here since dtype is `None`
475-
dot_dtype, res_dtype = _op_res_dtype(
476-
a, b, dtype=None, casting="no", sycl_queue=exec_q
477-
)
478-
479-
# create result array
480-
result = dpnp.empty(
481-
(),
482-
dtype=dot_dtype,
483-
usm_type=res_usm_type,
484-
sycl_queue=exec_q,
485-
)
486-
487-
# input arrays should have the proper data type
488-
dep_events_list = []
489-
host_tasks_list = []
490-
if dpnp.issubdtype(res_dtype, dpnp.inexact):
491-
# copying is needed if dtypes of input arrays are different
492-
a = _copy_array(a, dep_events_list, host_tasks_list, dtype=dot_dtype)
493-
b = _copy_array(b, dep_events_list, host_tasks_list, dtype=dot_dtype)
494-
if dpnp.issubdtype(res_dtype, dpnp.complexfloating):
495-
ht_ev, _ = bi._dotc(
496-
exec_q,
497-
dpnp.get_usm_ndarray(a),
498-
dpnp.get_usm_ndarray(b),
499-
dpnp.get_usm_ndarray(result),
500-
dep_events_list,
501-
)
502-
else:
503-
ht_ev, _ = bi._dot(
504-
exec_q,
505-
dpnp.get_usm_ndarray(a),
506-
dpnp.get_usm_ndarray(b),
507-
dpnp.get_usm_ndarray(result),
508-
dep_events_list,
509-
)
510-
host_tasks_list.append(ht_ev)
511-
dpctl.SyclEvent.wait_for(host_tasks_list)
512-
else:
513-
dpt_a = dpnp.get_usm_ndarray(a)
514-
dpt_b = dpnp.get_usm_ndarray(b)
515-
result = dpnp_array._create_from_usm_ndarray(dpt.vecdot(dpt_a, dpt_b))
516-
517-
if dot_dtype != res_dtype:
518-
result = result.astype(res_dtype, copy=False)
519-
520-
return result

tests/test_dot.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,10 @@ def test_vdot_scalar(self, dtype):
384384
expected = numpy.vdot(a, b)
385385
assert_allclose(result, expected)
386386

387+
result = dpnp.vdot(b, ia)
388+
expected = numpy.vdot(b, a)
389+
assert_allclose(result, expected)
390+
387391
@pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
388392
@pytest.mark.parametrize(
389393
"array_info",
@@ -505,3 +509,7 @@ def test_vdot_error(self):
505509
# The first array should be of size one
506510
with pytest.raises(ValueError):
507511
dpnp.vdot(a, b)
512+
513+
# The second array should be of size one
514+
with pytest.raises(ValueError):
515+
dpnp.vdot(b, a)

0 commit comments

Comments
 (0)