Skip to content

Commit 95e52ff

Browse files
Applied PR suggestions
1 parent ddb68e8 commit 95e52ff

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

dpctl/tensor/_ctors.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,22 +1212,30 @@ def meshgrid(*arrays, indexing="xy"):
12121212
If vectors are not of the same data type,
12131213
or are not one-dimensional, raises `ValueError.`
12141214
indexing: Cartesian (`xy`) or matrix (`ij`) indexing of output.
1215-
For a set of `n` vectors with lengths X0, X1, X2, ...
1215+
For a set of `n` vectors with lengths N0, N1, N2, ...
12161216
Cartesian indexing results in arrays of shape
1217-
(X1, X0, X2, ...)
1217+
(N1, N0, N2, ...)
12181218
matrix indexing results in arrays of shape
1219-
(X0, X1, X2, ...)
1219+
(n0, N1, N2, ...)
12201220
Default: `xy`.
12211221
"""
1222+
ref_dt = None
1223+
ref_unset = True
12221224
for array in arrays:
12231225
if not isinstance(array, dpt.usm_ndarray):
12241226
raise TypeError(
12251227
f"Expected instance of dpt.usm_ndarray, got {type(array)}."
12261228
)
12271229
if array.ndim != 1:
12281230
raise ValueError("All arrays must be one-dimensional.")
1229-
if len(set([array.dtype for array in arrays])) > 1:
1230-
raise ValueError("All arrays must be of the same numeric data type.")
1231+
if ref_unset:
1232+
ref_unset = False
1233+
ref_dt = array.dtype
1234+
else:
1235+
if not ref_dt == array.dtype:
1236+
raise ValueError(
1237+
"All arrays must be of the same numeric data type."
1238+
)
12311239
if indexing not in ["xy", "ij"]:
12321240
raise ValueError(
12331241
"Unrecognized indexing keyword value, expecting 'xy' or 'ij.'"

0 commit comments

Comments
 (0)