Skip to content

Commit 5060adc

Browse files
committed
use a function for error msg
1 parent af0ed2b commit 5060adc

File tree

1 file changed

+34
-37
lines changed

1 file changed

+34
-37
lines changed

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 34 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,31 @@ def _op_res_dtype(*arrays, dtype, casting, sycl_queue):
209209
return op_dtype, res_dtype
210210

211211

212+
def _shape_error(a, b, core_dim, err_msg):
213+
if err_msg == 0:
214+
raise ValueError(
215+
"Input arrays have a mismatch in their core dimensions. "
216+
"The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
217+
f"(size {a} is different from {b})"
218+
)
219+
elif err_msg == 1:
220+
raise ValueError(
221+
f"Output array has a mismatch in its core dimension {core_dim}. "
222+
"The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
223+
f"(size {a} is different from {b})"
224+
)
225+
elif err_msg == 2:
226+
raise ValueError(
227+
"Input arrays could not be broadcast together with remapped shapes, "
228+
f"{a} is different from {b}."
229+
)
230+
elif err_msg == 3:
231+
raise ValueError(
232+
"Output array could not be broadcast to input arrays with remapped shapes, "
233+
f"{a} is different from {b}."
234+
)
235+
236+
212237
def _standardize_strides(strides, inherently_2D, shape, ndim):
213238
"""
214239
Standardizing the strides.
@@ -436,42 +461,22 @@ def dpnp_matmul(
436461
x1_shape = x1.shape
437462
x2_shape = x2.shape
438463
if x1_shape[-1] != x2_shape[-2]:
439-
raise ValueError(
440-
"Input arrays have a mismatch in their core dimensions. "
441-
"The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
442-
f"(size {x1_shape[-1]} is different from {x2_shape[-2]})"
443-
)
464+
_shape_error(x1_shape[-1], x2_shape[-2], None, 0)
444465

445466
if out is not None:
446467
out_shape = out.shape
447468
if not appended_axes:
448469
if out_shape[-2] != x1_shape[-2]:
449-
raise ValueError(
450-
"Output array has a mismatch in its core dimension 0. "
451-
"The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
452-
f"(size {out_shape[-2]} is different from {x1_shape[-2]})"
453-
)
470+
_shape_error(out_shape[-2], x1_shape[-2], 0, 1)
454471
if out_shape[-1] != x2_shape[-1]:
455-
raise ValueError(
456-
"Output array has a mismatch in its core dimension 1. "
457-
"The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
458-
f"(size {out_shape[-1]} is different from {x2_shape[-1]})"
459-
)
472+
_shape_error(out_shape[-1], x2_shape[-1], 1, 1)
460473
elif len(appended_axes) == 1:
461474
if appended_axes[0] == -1:
462475
if out_shape[-1] != x1_shape[-2]:
463-
raise ValueError(
464-
"Output array has a mismatch in its core dimension 0. "
465-
"The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
466-
f"(size {out_shape[-1]} is different from {x1_shape[-2]})"
467-
)
476+
_shape_error(out_shape[-1], x1_shape[-2], 0, 1)
468477
elif appended_axes[0] == -2:
469478
if out_shape[-1] != x2_shape[-1]:
470-
raise ValueError(
471-
"Output array has a mismatch in its core dimension 0. "
472-
"The core dimensions should follow this signature: (n?,k),(k,m?)->(n?,m?) "
473-
f"(size {out_shape[-1]} is different from {x2_shape[-1]})"
474-
)
479+
_shape_error(out_shape[-1], x2_shape[-1], 0, 1)
475480

476481
# Determine the appropriate data types
477482
gemm_dtype, res_dtype = _op_res_dtype(
@@ -516,26 +521,18 @@ def dpnp_matmul(
516521
if not x2_is_2D:
517522
x2 = dpnp.repeat(x2, x1_shape[i], axis=i)
518523
else:
519-
raise ValueError(
520-
"Input arrays could not be broadcast together with remapped shapes, "
521-
f"{x1_shape[:-2]} is different from {x2_shape[:-2]}."
522-
)
524+
_shape_error(x1_shape[:-2], x2_shape[:-2], None, 2)
523525

524526
x1_shape = x1.shape
525527
x2_shape = x2.shape
526528
if out is not None:
527529
for i in range(x1_ndim - 2):
528530
if tmp_shape[i] != out_shape[i]:
529531
if not appended_axes:
530-
raise ValueError(
531-
"Output array could not be broadcast together with remapped shapes, "
532-
f"{tmp_shape[:-2]} is different from {out_shape[:-2]}."
533-
)
532+
_shape_error(tuple(tmp_shape), out_shape[:-2], None, 3)
534533
elif len(appended_axes) == 1:
535-
raise ValueError(
536-
"Output array could not be broadcast together with remapped shapes, "
537-
f"{tmp_shape[:-2]} is different from {out_shape[:-1]}."
538-
)
534+
_shape_error(tuple(tmp_shape), out_shape[:-1], None, 3)
535+
539536
res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1])
540537

541538
# handling a special case to provide a similar result to NumPy

0 commit comments

Comments
 (0)