Skip to content

Commit 60b7a93

Browse files
authored
Merge branch 'master' into flags-helper-class
2 parents de3bdcb + 72daccc commit 60b7a93

File tree

5 files changed

+117
-2
lines changed

5 files changed

+117
-2
lines changed

dpctl/memory/_memory.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,6 @@ cdef public api class MemoryUSMHost(_Memory) [object PyMemoryUSMHostObject,
7575
pass
7676

7777

78-
cdef public class MemoryUSMDevice(_Memory) [object PyMemoryUSMDeviceObject,
78+
cdef public api class MemoryUSMDevice(_Memory) [object PyMemoryUSMDeviceObject,
7979
type PyMemoryUSMDeviceType]:
8080
pass

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
full,
3232
full_like,
3333
linspace,
34+
meshgrid,
3435
ones,
3536
ones_like,
3637
tril,
@@ -87,4 +88,5 @@
8788
"from_dlpack",
8889
"tril",
8990
"triu",
91+
"meshgrid",
9092
]

dpctl/tensor/_ctors.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,3 +1198,61 @@ def triu(X, k=0):
11981198
hev.wait()
11991199

12001200
return res
1201+
1202+
1203+
def meshgrid(*arrays, indexing="xy"):
1204+
1205+
"""
1206+
meshgrid(*arrays, indexing="xy") -> list[usm_ndarray]
1207+
1208+
Creates list of `usm_ndarray` coordinate matrices from vectors.
1209+
1210+
Args:
1211+
arrays: arbitrary number of one-dimensional `USM_ndarray` objects.
1212+
If vectors are not of the same data type,
1213+
or are not one-dimensional, raises `ValueError.`
1214+
indexing: Cartesian (`xy`) or matrix (`ij`) indexing of output.
1215+
For a set of `n` vectors with lengths N0, N1, N2, ...
1216+
Cartesian indexing results in arrays of shape
1217+
(N1, N0, N2, ...)
1218+
matrix indexing results in arrays of shape
1219+
(n0, N1, N2, ...)
1220+
Default: `xy`.
1221+
"""
1222+
ref_dt = None
1223+
ref_unset = True
1224+
for array in arrays:
1225+
if not isinstance(array, dpt.usm_ndarray):
1226+
raise TypeError(
1227+
f"Expected instance of dpt.usm_ndarray, got {type(array)}."
1228+
)
1229+
if array.ndim != 1:
1230+
raise ValueError("All arrays must be one-dimensional.")
1231+
if ref_unset:
1232+
ref_unset = False
1233+
ref_dt = array.dtype
1234+
else:
1235+
if not ref_dt == array.dtype:
1236+
raise ValueError(
1237+
"All arrays must be of the same numeric data type."
1238+
)
1239+
if indexing not in ["xy", "ij"]:
1240+
raise ValueError(
1241+
"Unrecognized indexing keyword value, expecting 'xy' or 'ij.'"
1242+
)
1243+
n = len(arrays)
1244+
sh = (-1,) + (1,) * (n - 1)
1245+
1246+
res = []
1247+
if n > 1 and indexing == "xy":
1248+
res.append(dpt.reshape(arrays[0], (1, -1) + sh[2:], copy=True))
1249+
res.append(dpt.reshape(arrays[1], sh, copy=True))
1250+
arrays, sh = arrays[2:], sh[-2:] + sh[:-2]
1251+
1252+
for array in arrays:
1253+
res.append(dpt.reshape(array, sh, copy=True))
1254+
sh = sh[-1:] + sh[:-1]
1255+
1256+
output = dpt.broadcast_arrays(*res)
1257+
1258+
return output

dpctl/tensor/_reshape.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ def reshape(X, newshape, order="C", copy=None):
8686
raise TypeError
8787
if not isinstance(newshape, (list, tuple)):
8888
newshape = (newshape,)
89-
if order not in ["C", "F"]:
89+
if order in "cfCF":
90+
order = order.upper()
91+
else:
9092
raise ValueError(
9193
f"Keyword 'order' not recognized. Expecting 'C' or 'F', got {order}"
9294
)

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1427,6 +1427,57 @@ def test_tril_order_k(order, k):
14271427
assert np.array_equal(Ynp, dpt.asnumpy(Y))
14281428

14291429

1430+
def test_meshgrid():
1431+
try:
1432+
q = dpctl.SyclQueue()
1433+
except dpctl.SyclQueueCreationError:
1434+
pytest.skip("Queue could not be created")
1435+
X = dpt.arange(5, sycl_queue=q)
1436+
Y = dpt.arange(3, sycl_queue=q)
1437+
Z = dpt.meshgrid(X, Y)
1438+
Znp = np.meshgrid(dpt.asnumpy(X), dpt.asnumpy(Y))
1439+
n = len(Z)
1440+
assert n == len(Znp)
1441+
for i in range(n):
1442+
assert np.array_equal(dpt.asnumpy(Z[i]), Znp[i])
1443+
# dimension > 1 must raise ValueError
1444+
with pytest.raises(ValueError):
1445+
dpt.meshgrid(dpt.usm_ndarray((4, 4)))
1446+
# unknown indexing kwarg must raise ValueError
1447+
with pytest.raises(ValueError):
1448+
dpt.meshgrid(X, indexing="ji")
1449+
# input arrays with different data types must raise ValueError
1450+
with pytest.raises(ValueError):
1451+
dpt.meshgrid(X, dpt.asarray(Y, dtype="b1"))
1452+
1453+
1454+
def test_meshgrid2():
1455+
try:
1456+
q1 = dpctl.SyclQueue()
1457+
q2 = dpctl.SyclQueue()
1458+
q3 = dpctl.SyclQueue()
1459+
except dpctl.SyclQueueCreationError:
1460+
pytest.skip("Queue could not be created")
1461+
x1 = dpt.arange(0, 2, dtype="int16", sycl_queue=q1)
1462+
x2 = dpt.arange(3, 6, dtype="int16", sycl_queue=q2)
1463+
x3 = dpt.arange(6, 10, dtype="int16", sycl_queue=q3)
1464+
y1, y2, y3 = dpt.meshgrid(x1, x2, x3, indexing="xy")
1465+
z1, z2, z3 = dpt.meshgrid(x1, x2, x3, indexing="ij")
1466+
assert all(
1467+
x.sycl_queue == y.sycl_queue for x, y in zip((x1, x2, x3), (y1, y2, y3))
1468+
)
1469+
assert all(
1470+
x.sycl_queue == z.sycl_queue for x, z in zip((x1, x2, x3), (z1, z2, z3))
1471+
)
1472+
assert y1.shape == y2.shape and y2.shape == y3.shape
1473+
assert z1.shape == z2.shape and z2.shape == z3.shape
1474+
assert y1.shape == (len(x2), len(x1), len(x3))
1475+
assert z1.shape == (len(x1), len(x2), len(x3))
1476+
# FIXME: uncomment out once gh-921 is merged
1477+
# assert all(z.flags["C"] for z in (z1, z2, z3))
1478+
# assert all(y.flags["C"] for y in (y1, y2, y3))
1479+
1480+
14301481
def test_common_arg_validation():
14311482
order = "I"
14321483
# invalid order must raise ValueError
@@ -1463,6 +1514,8 @@ def test_common_arg_validation():
14631514
dpt.tril(X)
14641515
with pytest.raises(TypeError):
14651516
dpt.triu(X)
1517+
with pytest.raises(TypeError):
1518+
dpt.meshgrid(X)
14661519

14671520

14681521
def test_flags():

0 commit comments

Comments
 (0)