@@ -209,6 +209,31 @@ def _op_res_dtype(*arrays, dtype, casting, sycl_queue):
     return op_dtype, res_dtype


+def _shape_error(a, b, core_dim, err_msg):
+    if err_msg == 0:
+        raise ValueError(
+            "Input arrays have a mismatch in their core dimensions. "
+            "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
+            f"(size {a} is different from {b})"
+        )
+    elif err_msg == 1:
+        raise ValueError(
+            f"Output array has a mismatch in its core dimension {core_dim}. "
+            "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
+            f"(size {a} is different from {b})"
+        )
+    elif err_msg == 2:
+        raise ValueError(
+            "Input arrays could not be broadcast together with remapped shapes, "
+            f"{a} is different from {b}."
+        )
+    elif err_msg == 3:
+        raise ValueError(
+            "Output array could not be broadcast to input arrays with remapped shapes, "
+            f"{a} is different from {b}."
+        )
+
+
 def _standardize_strides(strides, inherently_2D, shape, ndim):
     """
     Standardizing the strides.
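Note: the new `_shape_error` helper centralizes the four inline `raise ValueError(...)` blocks removed below, dispatching on an integer error code. A minimal, standalone sketch of how the codes map to exceptions, assuming the helper above is in scope (the sizes are invented for illustration):

    # err_msg=0: the two inputs disagree on the shared core dimension k
    try:
        _shape_error(3, 4, None, 0)
    except ValueError as exc:
        print(exc)  # "... (size 3 is different from 4)"

    # err_msg=1: the output disagrees on core dimension `core_dim`
    try:
        _shape_error(5, 6, 1, 1)
    except ValueError as exc:
        print(exc)  # "Output array has a mismatch in its core dimension 1. ..."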
@@ -436,42 +461,22 @@ def dpnp_matmul(
     x1_shape = x1.shape
     x2_shape = x2.shape
     if x1_shape[-1] != x2_shape[-2]:
-        raise ValueError(
-            "Input arrays have a mismatch in their core dimensions. "
-            "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
-            f"(size {x1_shape[-1]} is different from {x2_shape[-2]})"
-        )
+        _shape_error(x1_shape[-1], x2_shape[-2], None, 0)

     if out is not None:
         out_shape = out.shape
         if not appended_axes:
             if out_shape[-2] != x1_shape[-2]:
-                raise ValueError(
-                    "Output array has a mismatch in its core dimension 0. "
-                    "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
-                    f"(size {out_shape[-2]} is different from {x1_shape[-2]})"
-                )
+                _shape_error(out_shape[-2], x1_shape[-2], 0, 1)
             if out_shape[-1] != x2_shape[-1]:
-                raise ValueError(
-                    "Output array has a mismatch in its core dimension 1. "
-                    "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
-                    f"(size {out_shape[-1]} is different from {x2_shape[-1]})"
-                )
+                _shape_error(out_shape[-1], x2_shape[-1], 1, 1)
         elif len(appended_axes) == 1:
             if appended_axes[0] == -1:
                 if out_shape[-1] != x1_shape[-2]:
-                    raise ValueError(
-                        "Output array has a mismatch in its core dimension 0. "
-                        "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
-                        f"(size {out_shape[-1]} is different from {x1_shape[-2]})"
-                    )
+                    _shape_error(out_shape[-1], x1_shape[-2], 0, 1)
             elif appended_axes[0] == -2:
                 if out_shape[-1] != x2_shape[-1]:
-                    raise ValueError(
-                        "Output array has a mismatch in its core dimension 0. "
-                        "The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
-                        f"(size {out_shape[-1]} is different from {x2_shape[-1]})"
-                    )
+                    _shape_error(out_shape[-1], x2_shape[-1], 0, 1)

     # Determine the appropriate data types
     gemm_dtype, res_dtype = _op_res_dtype(
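For context, a hedged sketch of the user-visible behavior this hunk preserves (assuming a dpnp build that includes these checks, with example shapes invented here): `dpnp.matmul` follows the (n?,k),(k,m?)->(n?,m?) signature, so a mismatch in the shared k dimension raises the err_msg=0 message.

    import dpnp

    a = dpnp.ones((2, 3))
    b = dpnp.ones((4, 5))  # core dimensions 3 and 4 do not match
    try:
        dpnp.matmul(a, b)
    except ValueError as exc:
        print(exc)  # "Input arrays have a mismatch in their core dimensions. ..."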
@@ -516,26 +521,18 @@ def dpnp_matmul(
                     if not x2_is_2D:
                         x2 = dpnp.repeat(x2, x1_shape[i], axis=i)
                 else:
-                    raise ValueError(
-                        "Input arrays could not be broadcast together with remapped shapes, "
-                        f"{x1_shape[:-2]} is different from {x2_shape[:-2]}."
-                    )
+                    _shape_error(x1_shape[:-2], x2_shape[:-2], None, 2)

         x1_shape = x1.shape
         x2_shape = x2.shape
         if out is not None:
             for i in range(x1_ndim - 2):
                 if tmp_shape[i] != out_shape[i]:
                     if not appended_axes:
-                        raise ValueError(
-                            "Output array could not be broadcast together with remapped shapes, "
-                            f"{tmp_shape[:-2]} is different from {out_shape[:-2]}."
-                        )
+                        _shape_error(tuple(tmp_shape), out_shape[:-2], None, 3)
                     elif len(appended_axes) == 1:
-                        raise ValueError(
-                            "Output array could not be broadcast together with remapped shapes, "
-                            f"{tmp_shape[:-2]} is different from {out_shape[:-1]}."
-                        )
+                        _shape_error(tuple(tmp_shape), out_shape[:-1], None, 3)
+
         res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1])

     # handling a special case to provide a similar result to NumPy