27
27
import dpctl .tensor as dpt
28
28
import dpctl .tensor ._tensor_impl as ti
29
29
import numpy
30
+ from numpy .core .numeric import normalize_axis_tuple
30
31
31
32
import dpnp
32
33
import dpnp .backend .extensions .blas ._blas_impl as bi
@@ -43,7 +44,7 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
43
44
If `out` is not ``None`` and its features match the specified `shape`, `dtype,
44
45
`usm_type`, and `sycl_queue` and it is C-contiguous or F-contiguous and
45
46
does not have any memory overlap with `x1` and `x2`, `out` itself is returned.
46
- If these conditions are not statisfied , an empty array is returned with the
47
+ If these conditions are not satisfied , an empty array is returned with the
47
48
specified `shape`, `dtype, `usm_type`, and `sycl_queue`.
48
49
"""
49
50
@@ -116,21 +117,9 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list):
116
117
x2_strides = x2 .strides
117
118
res_strides = res .strides
118
119
119
- # when shape along any particular dimension is 1,
120
- # the stride along that dimension is not a
121
- # meaningful number and is undefined. Here, we
122
- # standardizing strides before continuing,
123
- # setting stride to 0 if the shape along that axis is <=1
124
- if x1_is_2D :
125
- x1_strides = tuple (
126
- str_i if sh_i > 1 else 0
127
- for sh_i , str_i in zip (x1 .shape , x1_strides )
128
- )
129
- if x2_is_2D :
130
- x2_strides = tuple (
131
- str_i if sh_i > 1 else 0
132
- for sh_i , str_i in zip (x2 .shape , x2_strides )
133
- )
120
+ # need to standardize to use in ti._contract_iter2
121
+ x1_strides = _standardize_strides (x1_strides , x1_is_2D , x1 .shape , x1 .ndim )
122
+ x2_strides = _standardize_strides (x2_strides , x2_is_2D , x2 .shape , x2 .ndim )
134
123
135
124
batch_size = res .shape [:- 2 ][0 ]
136
125
stridea = x1_strides [0 ]
@@ -220,6 +209,92 @@ def _op_res_dtype(*arrays, dtype, casting, sycl_queue):
220
209
return op_dtype , res_dtype
221
210
222
211
212
+ def _standardize_strides (strides , inherently_2D , shape , ndim ):
213
+ """
214
+ Standardizing the strides.
215
+
216
+ When shape of an array along any particular dimension is 1, the stride
217
+ along that dimension is undefined. This functions standardize the strides
218
+ in the following way:
219
+ For N-D arrays that are inherently 2D (all dimesnsion are one except for two of them),
220
+ we use zero as the stride for dimensions equal one.
221
+ For other N-D arrays, the non-zero value of strides is calculated and used.
222
+
223
+ """
224
+
225
+ if inherently_2D :
226
+ stndrd_strides = tuple (
227
+ str_i if sh_i > 1 else 0 for sh_i , str_i in zip (shape , strides )
228
+ )
229
+ else :
230
+ stndrd_strides = [
231
+ numpy .prod (shape [i + 1 :]) if strides [i ] == 0 else strides [i ]
232
+ for i in range (ndim - 1 )
233
+ ]
234
+ # last dimension
235
+ stndrd_strides .append (
236
+ 1 if strides [ndim - 1 ] == 0 else strides [ndim - 1 ]
237
+ )
238
+ stndrd_strides = tuple (stndrd_strides )
239
+
240
+ return stndrd_strides
241
+
242
+
243
+ def _validate_axes (x1 , x2 , axes ):
244
+ """Check axes is valid for matmul function."""
245
+
246
+ def _validate_internal (axes , i , ndim ):
247
+ if ndim == 1 :
248
+ iter = 1
249
+ if isinstance (axes , int ):
250
+ axes = (axes ,)
251
+ elif not isinstance (axes , tuple ):
252
+ raise TypeError (
253
+ f"Axes item { i } : { type (axes )} object cannot be interpreted as an integer."
254
+ )
255
+
256
+ if len (axes ) != 1 :
257
+ raise ValueError (
258
+ f"Axes item { i } should be a tuple with a single element, or an integer."
259
+ )
260
+ else :
261
+ iter = 2
262
+ if not isinstance (axes , tuple ):
263
+ raise TypeError (f"Axes item { i } should be a tuple." )
264
+ if len (axes ) != 2 :
265
+ raise ValueError (
266
+ f"Axes item { i } should be a tuple with 2 elements."
267
+ )
268
+
269
+ for j in range (iter ):
270
+ if not isinstance (axes [j ], int ):
271
+ raise TypeError (
272
+ f"Axes item { i } : { type (axes [j ])} object cannot be interpreted as an integer."
273
+ )
274
+ return axes
275
+
276
+ if not isinstance (axes , list ):
277
+ raise TypeError ("Axes should be a list." )
278
+ else :
279
+ if len (axes ) != 3 :
280
+ raise ValueError (
281
+ "Axes should be a list of three tuples for inputs and output."
282
+ )
283
+
284
+ axes [0 ] = _validate_internal (axes [0 ], 0 , x1 .ndim )
285
+ axes [1 ] = _validate_internal (axes [1 ], 1 , x2 .ndim )
286
+
287
+ if x1 .ndim == 1 and x2 .ndim == 1 :
288
+ if axes [2 ] != ():
289
+ raise TypeError ("Axes item 2 should be an empty tuple." )
290
+ elif x1 .ndim == 1 or x2 .ndim == 1 :
291
+ axes [2 ] = _validate_internal (axes [2 ], 2 , 1 )
292
+ else :
293
+ axes [2 ] = _validate_internal (axes [2 ], 2 , 2 )
294
+
295
+ return axes
296
+
297
+
223
298
def dpnp_dot (a , b , / , out = None , * , conjugate = False ):
224
299
"""
225
300
Return the dot product of two arrays.
@@ -302,6 +377,7 @@ def dpnp_matmul(
302
377
casting = "same_kind" ,
303
378
order = "K" ,
304
379
dtype = None ,
380
+ axes = None ,
305
381
):
306
382
"""
307
383
Return the matrix product of two arrays.
@@ -327,6 +403,22 @@ def dpnp_matmul(
327
403
328
404
res_usm_type , exec_q = get_usm_allocations ([x1 , x2 ])
329
405
406
+ if axes is not None :
407
+ axes = _validate_axes (x1 , x2 , axes )
408
+
409
+ axes_x1 , axes_x2 , axes_res = axes
410
+ axes_x1 = normalize_axis_tuple (axes_x1 , x1 .ndim , "axis" )
411
+ axes_x2 = normalize_axis_tuple (axes_x2 , x2 .ndim , "axis" )
412
+ # Move the axes that are going to be used in matrix product,
413
+ # to the end of "x1" and "x2"
414
+ x1 = dpnp .moveaxis (x1 , axes_x1 , (- 2 , - 1 )) if x1 .ndim != 1 else x1
415
+ x2 = dpnp .moveaxis (x2 , axes_x2 , (- 2 , - 1 )) if x2 .ndim != 1 else x2
416
+ out_orig = out
417
+ if out is not None :
418
+ dpnp .check_supported_arrays_type (out )
419
+ # out that is passed to the backend should have the correct shape
420
+ out = dpnp .moveaxis (out , axes_res , (- 2 , - 1 ))
421
+
330
422
appended_axes = []
331
423
if x1_ndim == 1 :
332
424
x1 = x1 [dpnp .newaxis , :]
@@ -397,9 +489,15 @@ def dpnp_matmul(
397
489
x2_shape = x2 .shape
398
490
res_shape = tuple (tmp_shape ) + (x1_shape [- 2 ], x2_shape [- 1 ])
399
491
492
+ # handling a special case to provide a similar result to NumPy
493
+ if out is not None and x1 .shape == (1 , 0 ) and x2 .shape == (0 , 1 ):
494
+ res_shape = (0 ,)
495
+ appended_axes = []
496
+
400
497
result = _create_result_array (
401
498
x1 , x2 , out , res_shape , gemm_dtype , res_usm_type , exec_q
402
499
)
500
+
403
501
# calculate result
404
502
if result .size == 0 :
405
503
pass
@@ -471,12 +569,25 @@ def dpnp_matmul(
471
569
472
570
if gemm_dtype != res_dtype :
473
571
result = dpnp .astype (result , res_dtype , copy = False )
572
+
474
573
if out is None :
574
+ if axes is not None :
575
+ # Move the result to the appropriate axes of out array
576
+ if len (axes_res ) == 2 :
577
+ result = dpnp .moveaxis (result , (- 2 , - 1 ), axes_res )
578
+ elif len (axes_res ) == 1 :
579
+ result = dpnp .moveaxis (result , (- 1 ,), axes_res )
580
+ return result
475
581
# If `order` was not passed as default
476
582
# we need to update it to match the passed `order`.
477
- if order not in ["k" , "K" ]:
583
+ elif order not in ["k" , "K" ]:
478
584
return dpnp .array (result , copy = False , order = order )
479
585
else :
480
586
return result
481
587
else :
482
- return dpnp .get_result_array (result , out , casting = casting )
588
+ result = dpnp .get_result_array (result , out , casting = casting )
589
+ if axes is not None :
590
+ if out is result :
591
+ # out and out_orig contain the same data but they have different shape
592
+ return out_orig
593
+ return result
0 commit comments