33
33
from dpnp .dpnp_array import dpnp_array
34
34
from dpnp .dpnp_utils import get_usm_allocations
35
35
36
- __all__ = ["dpnp_dot" , "dpnp_matmul" , "dpnp_vdot" ]
36
+ __all__ = ["dpnp_dot" , "dpnp_matmul" ]
37
37
38
38
39
39
def _copy_array (x , dep_events , host_events , contig_copy = False , dtype = None ):
@@ -185,7 +185,7 @@ def _op_res_dtype(*arrays, dtype, casting, sycl_queue):
185
185
return op_dtype , res_dtype
186
186
187
187
188
- def dpnp_dot (a , b , / , out = None ):
188
+ def dpnp_dot (a , b , / , out = None , * , conjugate = False ):
189
189
"""
190
190
Return the dot product of two arrays.
191
191
@@ -194,7 +194,9 @@ def dpnp_dot(a, b, /, out=None):
194
194
`dpctl.tensor.vecdot` form the Data Parallel Control library is used,
195
195
2) For real-valued floating point data types, `dot` routines from
196
196
BLAS library of OneMKL are used, and 3) For complex data types,
197
- `dotu` routines from BLAS library of OneMKL are used.
197
+ `dotu` or `dotc` routines from BLAS library of OneMKL are used.
198
+ If `conjugate` is ``False``, `dotu` is used. Otherwise, `dotc` is used,
199
+ for which the first array is conjugated before calculating the dot product.
198
200
199
201
"""
200
202
@@ -228,13 +230,22 @@ def dpnp_dot(a, b, /, out=None):
228
230
a = _copy_array (a , dep_events_list , host_tasks_list , dtype = dot_dtype )
229
231
b = _copy_array (b , dep_events_list , host_tasks_list , dtype = dot_dtype )
230
232
if dpnp .issubdtype (res_dtype , dpnp .complexfloating ):
231
- ht_ev , _ = bi ._dotu (
232
- exec_q ,
233
- dpnp .get_usm_ndarray (a ),
234
- dpnp .get_usm_ndarray (b ),
235
- dpnp .get_usm_ndarray (result ),
236
- dep_events_list ,
237
- )
233
+ if conjugate :
234
+ ht_ev , _ = bi ._dotc (
235
+ exec_q ,
236
+ dpnp .get_usm_ndarray (a ),
237
+ dpnp .get_usm_ndarray (b ),
238
+ dpnp .get_usm_ndarray (result ),
239
+ dep_events_list ,
240
+ )
241
+ else :
242
+ ht_ev , _ = bi ._dotu (
243
+ exec_q ,
244
+ dpnp .get_usm_ndarray (a ),
245
+ dpnp .get_usm_ndarray (b ),
246
+ dpnp .get_usm_ndarray (result ),
247
+ dep_events_list ,
248
+ )
238
249
else :
239
250
ht_ev , _ = bi ._dot (
240
251
exec_q ,
@@ -253,7 +264,7 @@ def dpnp_dot(a, b, /, out=None):
253
264
if dot_dtype != res_dtype :
254
265
result = result .astype (res_dtype , copy = False )
255
266
256
- # NumPy does not allow casting even if it is safe
267
+ # numpy.dot does not allow casting even if it is safe
257
268
return dpnp .get_result_array (result , out , casting = "no" )
258
269
259
270
@@ -447,74 +458,3 @@ def dpnp_matmul(
447
458
return result
448
459
else :
449
460
return dpnp .get_result_array (result , out , casting = casting )
450
-
451
-
452
- def dpnp_vdot (a , b ):
453
- """
454
- Return the dot product of two arrays.
455
-
456
- The routine that is used to perform the main calculation
457
- depends on input arrays data type: 1) For integer and boolean data types,
458
- `dpctl.tensor.vecdot` form the Data Parallel Control library is used,
459
- 2) For real-valued floating point data types, `dot` routines from
460
- BLAS library of OneMKL are used, and 3) For complex data types,
461
- `dotc` routines from BLAS library of OneMKL are used.
462
-
463
- """
464
-
465
- if a .size != b .size :
466
- raise ValueError (
467
- "Input arrays have a mismatch in their size. "
468
- f"(size { a .size } is different from { b .size } )"
469
- )
470
-
471
- res_usm_type , exec_q = get_usm_allocations ([a , b ])
472
-
473
- # Determine the appropriate data types
474
- # casting is irrelevant here since dtype is `None`
475
- dot_dtype , res_dtype = _op_res_dtype (
476
- a , b , dtype = None , casting = "no" , sycl_queue = exec_q
477
- )
478
-
479
- # create result array
480
- result = dpnp .empty (
481
- (),
482
- dtype = dot_dtype ,
483
- usm_type = res_usm_type ,
484
- sycl_queue = exec_q ,
485
- )
486
-
487
- # input arrays should have the proper data type
488
- dep_events_list = []
489
- host_tasks_list = []
490
- if dpnp .issubdtype (res_dtype , dpnp .inexact ):
491
- # copying is needed if dtypes of input arrays are different
492
- a = _copy_array (a , dep_events_list , host_tasks_list , dtype = dot_dtype )
493
- b = _copy_array (b , dep_events_list , host_tasks_list , dtype = dot_dtype )
494
- if dpnp .issubdtype (res_dtype , dpnp .complexfloating ):
495
- ht_ev , _ = bi ._dotc (
496
- exec_q ,
497
- dpnp .get_usm_ndarray (a ),
498
- dpnp .get_usm_ndarray (b ),
499
- dpnp .get_usm_ndarray (result ),
500
- dep_events_list ,
501
- )
502
- else :
503
- ht_ev , _ = bi ._dot (
504
- exec_q ,
505
- dpnp .get_usm_ndarray (a ),
506
- dpnp .get_usm_ndarray (b ),
507
- dpnp .get_usm_ndarray (result ),
508
- dep_events_list ,
509
- )
510
- host_tasks_list .append (ht_ev )
511
- dpctl .SyclEvent .wait_for (host_tasks_list )
512
- else :
513
- dpt_a = dpnp .get_usm_ndarray (a )
514
- dpt_b = dpnp .get_usm_ndarray (b )
515
- result = dpnp_array ._create_from_usm_ndarray (dpt .vecdot (dpt_a , dpt_b ))
516
-
517
- if dot_dtype != res_dtype :
518
- result = result .astype (res_dtype , copy = False )
519
-
520
- return result
0 commit comments