Skip to content

Commit ddb68e8

Browse files
committed
Implemented dpctl.tensor.meshgrid and tests
1 parent 7b368a1 commit ddb68e8

File tree

3 files changed

+78
-0
lines changed

3 files changed

+78
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
full,
3232
full_like,
3333
linspace,
34+
meshgrid,
3435
ones,
3536
ones_like,
3637
tril,
@@ -87,4 +88,5 @@
8788
"from_dlpack",
8889
"tril",
8990
"triu",
91+
"meshgrid",
9092
]

dpctl/tensor/_ctors.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,3 +1198,53 @@ def triu(X, k=0):
11981198
hev.wait()
11991199

12001200
return res
1201+
1202+
1203+
def meshgrid(*arrays, indexing="xy"):
1204+
1205+
"""
1206+
meshgrid(*arrays, indexing="xy") -> list[usm_ndarray]
1207+
1208+
Creates list of `usm_ndarray` coordinate matrices from vectors.
1209+
1210+
Args:
1211+
arrays: arbitrary number of one-dimensional `USM_ndarray` objects.
1212+
If vectors are not of the same data type,
1213+
or are not one-dimensional, raises `ValueError.`
1214+
indexing: Cartesian (`xy`) or matrix (`ij`) indexing of output.
1215+
For a set of `n` vectors with lengths X0, X1, X2, ...
1216+
Cartesian indexing results in arrays of shape
1217+
(X1, X0, X2, ...)
1218+
matrix indexing results in arrays of shape
1219+
(X0, X1, X2, ...)
1220+
Default: `xy`.
1221+
"""
1222+
for array in arrays:
1223+
if not isinstance(array, dpt.usm_ndarray):
1224+
raise TypeError(
1225+
f"Expected instance of dpt.usm_ndarray, got {type(array)}."
1226+
)
1227+
if array.ndim != 1:
1228+
raise ValueError("All arrays must be one-dimensional.")
1229+
if len(set([array.dtype for array in arrays])) > 1:
1230+
raise ValueError("All arrays must be of the same numeric data type.")
1231+
if indexing not in ["xy", "ij"]:
1232+
raise ValueError(
1233+
"Unrecognized indexing keyword value, expecting 'xy' or 'ij.'"
1234+
)
1235+
n = len(arrays)
1236+
sh = (-1,) + (1,) * (n - 1)
1237+
1238+
res = []
1239+
if n > 1 and indexing == "xy":
1240+
res.append(dpt.reshape(arrays[0], (1, -1) + sh[2:], copy=True))
1241+
res.append(dpt.reshape(arrays[1], sh, copy=True))
1242+
arrays, sh = arrays[2:], sh[-2:] + sh[:-2]
1243+
1244+
for array in arrays:
1245+
res.append(dpt.reshape(array, sh, copy=True))
1246+
sh = sh[-1:] + sh[:-1]
1247+
1248+
output = dpt.broadcast_arrays(*res)
1249+
1250+
return output

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1427,6 +1427,30 @@ def test_tril_order_k(order, k):
14271427
assert np.array_equal(Ynp, dpt.asnumpy(Y))
14281428

14291429

1430+
def test_meshgrid():
1431+
try:
1432+
q = dpctl.SyclQueue()
1433+
except dpctl.SyclQueueCreationError:
1434+
pytest.skip("Queue could not be created")
1435+
X = dpt.arange(5, sycl_queue=q)
1436+
Y = dpt.arange(3, sycl_queue=q)
1437+
Z = dpt.meshgrid(X, Y)
1438+
Znp = np.meshgrid(dpt.asnumpy(X), dpt.asnumpy(Y))
1439+
n = len(Z)
1440+
assert n == len(Znp)
1441+
for i in range(n):
1442+
assert np.array_equal(dpt.asnumpy(Z[i]), Znp[i])
1443+
# dimension > 1 must raise ValueError
1444+
with pytest.raises(ValueError):
1445+
dpt.meshgrid(dpt.usm_ndarray((4, 4)))
1446+
# unknown indexing kwarg must raise ValueError
1447+
with pytest.raises(ValueError):
1448+
dpt.meshgrid(X, indexing="ji")
1449+
# input arrays with different data types must raise ValueError
1450+
with pytest.raises(ValueError):
1451+
dpt.meshgrid(X, dpt.asarray(Y, dtype="b1"))
1452+
1453+
14301454
def test_common_arg_validation():
14311455
order = "I"
14321456
# invalid order must raise ValueError
@@ -1463,3 +1487,5 @@ def test_common_arg_validation():
14631487
dpt.tril(X)
14641488
with pytest.raises(TypeError):
14651489
dpt.triu(X)
1490+
with pytest.raises(TypeError):
1491+
dpt.meshgrid(X)

0 commit comments

Comments
 (0)