Skip to content

Commit bafe125

Browse files
Fill exception messages
1 parent c139d2c commit bafe125

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -470,12 +470,17 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
470470
by the value of ``mode`` keyword.
471471
"""
472472
if not isinstance(x, dpt.usm_ndarray):
473-
raise TypeError
473+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
474474
if not isinstance(indices, dpt.usm_ndarray):
475-
raise TypeError
475+
raise TypeError(
476+
f"Expected dpctl.tensor.usm_ndarray, got {type(indices)}"
477+
)
476478
x_nd = x.ndim
477479
if x_nd != indices.ndim:
478-
raise ValueError
480+
raise ValueError(
481+
"Number of dimensions in the first and the second "
482+
"argument arrays must be equal"
483+
)
479484
pp = normalize_axis_index(operator.index(axis), x_nd)
480485
out_usm_type = dpctl.utils.get_coerced_usm_type(
481486
(x.usm_type, indices.usm_type)

0 commit comments

Comments
 (0)