@@ -134,11 +134,12 @@ def _create_result_array(
134
134
"""
135
135
Create the result array.
136
136
137
- If `out` is not ``None`` and its features match the specified `shape`, `dtype,
138
- `usm_type `, and `sycl_queue` and it is C- contiguous or F-contiguous and
139
- does not have any memory overlap with `x1` and `x2`, `out` itself is returned.
137
+ If `out` is not ``None`` and its shape and dtype match the desired `shape`
138
+ and `dtype `, and its 2-D base is contiguous and it does not have any memory
139
+ overlap with `x1` and `x2`, `out` itself is returned.
140
140
If these conditions are not satisfied, an empty array is returned with the
141
141
specified `shape`, `dtype`, `usm_type`, and `sycl_queue`.
142
+
142
143
"""
143
144
144
145
if out is not None :
@@ -150,7 +151,6 @@ def _create_result_array(
150
151
if (
151
152
out .dtype == dtype
152
153
and out .shape == shape
153
- and out .usm_type == usm_type
154
154
and contig_flag
155
155
and not ti ._array_overlap (x1_usm , out_usm )
156
156
and not ti ._array_overlap (x2_usm , out_usm )
@@ -325,10 +325,13 @@ def _get_result_shape(x1, x2, out, np_flag):
325
325
326
326
def _gemm_batch_matmul (exec_q , x1 , x2 , res ):
327
327
# arrays here are already at least 3D, make them 3D
328
- x1 = dpnp .reshape (x1 , (- 1 , x1 .shape [- 2 ], x1 .shape [- 1 ]))
329
- x2 = dpnp .reshape (x2 , (- 1 , x2 .shape [- 2 ], x2 .shape [- 1 ]))
328
+ x1_shape = x1 .shape
329
+ x2_shape = x2 .shape
330
+ x1 = dpnp .reshape (x1 , (- 1 , x1_shape [- 2 ], x1_shape [- 1 ]))
331
+ x2 = dpnp .reshape (x2 , (- 1 , x2_shape [- 2 ], x2_shape [- 1 ]))
330
332
orig_shape = res .shape
331
- res = dpnp .reshape (res , (- 1 , res .shape [- 2 ], res .shape [- 1 ]))
333
+ res = dpnp .reshape (res , (- 1 , orig_shape [- 2 ], orig_shape [- 1 ]))
334
+ res_shape = res .shape
332
335
333
336
# gemm_batch does not handle negative strides, make a copy if needed
334
337
x1 = _copy_array (x1 , copy_flag = x1 .strides [0 ] < 0 )
@@ -338,16 +341,16 @@ def _gemm_batch_matmul(exec_q, x1, x2, res):
338
341
_manager = dpu .SequentialOrderManager [exec_q ]
339
342
340
343
# onemkl::blas::gemm_batch throws an exception (Provided range is out
341
- # of integer limits) if the batch_size is too large (>=4096*4096) , so
342
- # we need to split the batch into smaller chunks
343
- chunk = 2048 * 2048
344
- batch_size = res . shape [0 ]
344
+ # of integer limits) if the batch_size is too large, so we need to
345
+ # split the batch into smaller chunks, the size depends on device
346
+ chunk = 4096 * 4096 - 2
347
+ batch_size = res_shape [0 ]
345
348
for i in range (0 , batch_size , chunk ):
346
- if x1 . shape [0 ] == 1 :
349
+ if x1_shape [0 ] == 1 :
347
350
# x1 is repeatedly multiplied with each matrix in x2
348
351
x1_usm = dpnp .get_usm_ndarray (x1 )
349
352
x2_usm = dpnp .get_usm_ndarray (x2 [i : i + chunk , ...])
350
- elif x2 . shape [0 ] == 1 :
353
+ elif x2_shape [0 ] == 1 :
351
354
x1_usm = dpnp .get_usm_ndarray (x1 [i : i + chunk , ...])
352
355
x2_usm = dpnp .get_usm_ndarray (x2 )
353
356
else :
@@ -364,25 +367,36 @@ def _gemm_batch_matmul(exec_q, x1, x2, res):
364
367
)
365
368
_manager .add_event_pair (ht_ev , blas_ev )
366
369
367
- res_shape = res .shape
368
370
_ , res_is_c_contig , res_is_f_contig = _define_contig_flag (res )
369
371
if row_major :
370
372
if res_is_f_contig :
371
- res = dpnp .reshape (
372
- dpnp .ravel (res , order = "F" ),
373
- (res_shape [1 ], res_shape [2 ], batch_size ),
374
- ).transpose (2 , 0 , 1 )
373
+ # Considering the multiplication for one of the batches,
374
+ # we have result[0, 1] = a[0, :]*b[1, :]. In row_major mode,
375
+ # it is assumed result array is c-contiguous, i.e. the value of
376
+ # result[0, 1] occupies the second place in memory.
377
+ # however, the result array is a batch of 2D f-contiguous arrays,
378
+ # i.e. the second place in memory points to res[1, 0].
379
+ # So, we need to read data of each 2D array in the batch in
380
+ # "F" order and write it in "C" order
381
+ res = (
382
+ res .ravel (order = "F" )
383
+ .reshape (res_shape [1 ], res_shape [2 ], batch_size )
384
+ .transpose (2 , 0 , 1 )
385
+ )
375
386
else :
376
387
if res_is_c_contig :
377
- res = dpnp .reshape (
378
- dpnp .ravel (res , order = "C" ),
379
- (batch_size , res_shape [2 ], res_shape [1 ]),
380
- ).transpose (0 , 2 , 1 )
388
+ # read data of each 2D array in the batch in "C" order and
389
+ # write it in "F" order
390
+ res = (
391
+ res .ravel (order = "C" )
392
+ .reshape (batch_size , res_shape [2 ], res_shape [1 ])
393
+ .transpose (0 , 2 , 1 )
394
+ )
381
395
382
396
if res_shape != orig_shape :
383
397
res = res .reshape (orig_shape )
384
398
385
- return dpnp . ascontiguousarray ( res )
399
+ return res
386
400
387
401
388
402
def _gemm_matmul (exec_q , x1 , x2 , res ):
@@ -400,13 +414,13 @@ def _gemm_matmul(exec_q, x1, x2, res):
400
414
if row_major :
401
415
if res .flags .f_contiguous is True :
402
416
# read data in "F" order and write it in "C" order
403
- res = dpnp .reshape ( dpnp . ravel (res , order = "F" ), res .shape , order = "C" )
417
+ res = dpnp .ravel (res , order = "F" ). reshape ( res .shape , order = "C" )
404
418
else :
405
419
if res .flags .c_contiguous is True :
406
420
# read data in "C" order and write it in "F" order
407
- res = dpnp .reshape ( dpnp . ravel (res , order = "C" ), res .shape , order = "F" )
421
+ res = dpnp .ravel (res , order = "C" ). reshape ( res .shape , order = "F" )
408
422
409
- return dpnp . ascontiguousarray ( res )
423
+ return res
410
424
411
425
412
426
def _shape_error (a , b , core_dim , err_msg ):
@@ -767,9 +781,9 @@ def dpnp_matmul(
767
781
call_flag = "multiply"
768
782
elif x1_is_1D and x2_is_1D :
769
783
call_flag = "dot"
770
- x1 = dpnp . reshape ( x1 , x1_shape [ - 1 ])
771
- if x2_ndim != 1 :
772
- x2 = dpnp .reshape (x2 , x2_shape [ - 2 ] )
784
+ # arrays are inherently 1D, make them 1D
785
+ x1 = dpnp . ravel ( x1 )
786
+ x2 = dpnp .ravel (x2 )
773
787
elif x1_base_is_1D and x2_base_is_1D :
774
788
# TODO: implement a batch version of dot to use it here
775
789
call_flag = "gemm_batch"
@@ -912,12 +926,11 @@ def dpnp_matmul(
912
926
# we need to update it to match the passed `order`.
913
927
if order not in ["k" , "K" ]:
914
928
return dpnp .array (result , copy = False , order = order )
915
- return result
929
+ # dpnp.ascontiguousarray changes 0-D array to 1-D array
930
+ if result .ndim == 0 :
931
+ return result
932
+ return dpnp .ascontiguousarray (result )
916
933
917
- # TODO: There is opportunity to improve performance when out keyword is
918
- # present. For some cases, out is NOT result but they have the same base
919
- # (They are views of the same data). In this case, we can avoid copyign
920
- # result to out.
921
934
result = dpnp .get_result_array (result , out , casting = casting )
922
935
if axes is not None and out is result :
923
936
# out and out_orig contain the same data but they have different shape
0 commit comments