Skip to content

Commit 6cc7bac

Browse files
committed
Define __hash__
1 parent 325b9d0 commit 6cc7bac

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

array_api_strict/_array_object.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,14 @@ def __repr__(self):
5353
return f"array_api_strict.Device('{self._device}')"
5454

5555
def __eq__(self, other):
56+
if not isinstance(other, Device):
57+
return False
5658
return self._device == other._device
5759

60+
def __hash__(self):
61+
return hash(("Device", self._device))
62+
63+
5864
CPU_DEVICE = Device()
5965

6066
_default = object()

array_api_strict/_creation_functions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,11 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
307307
if len({a.dtype for a in arrays}) > 1:
308308
raise ValueError("meshgrid inputs must all have the same dtype")
309309

310+
if len({a.device for a in arrays}) > 1:
311+
raise ValueError("meshgrid inputs must all be on the same device")
312+
310313
return [
311-
Array._new(array, device=device)
314+
Array._new(array, device=array.device)
312315
for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
313316
]
314317

0 commit comments

Comments
 (0)