Skip to content

Commit 426609f

Browse files
committed
Fix meshgrid
1 parent bca670a commit 426609f

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

array_api_strict/_creation_functions.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,14 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
310310
if len({a.device for a in arrays}) > 1:
311311
raise ValueError("meshgrid inputs must all be on the same device")
312312

313+
# arrays is allowed to be empty
314+
if arrays:
315+
device = arrays[0].device
316+
else:
317+
device = None
318+
313319
return [
314-
Array._new(array, device=array.device)
320+
Array._new(array, device=device)
315321
for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
316322
]
317323

0 commit comments

Comments
 (0)