Skip to content

Commit a544741

Browse files
committed
EqualityMapping class
1 parent 9816011 commit a544741

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import Mapping
12
from functools import lru_cache
23
from typing import NamedTuple, Tuple, Union
34
from warnings import warn
@@ -99,8 +100,8 @@ class MinMax(NamedTuple):
99100
xp.uint16: MinMax(0, +65_535),
100101
xp.uint32: MinMax(0, +4_294_967_295),
101102
xp.uint64: MinMax(0, +18_446_744_073_709_551_615),
102-
xp.float32: MinMax(-3.4028234663852886e+38, 3.4028234663852886e+38),
103-
xp.float64: MinMax(-1.7976931348623157e+308, 1.7976931348623157e+308),
103+
xp.float32: MinMax(-3.4028234663852886e38, 3.4028234663852886e38),
104+
xp.float64: MinMax(-1.7976931348623157e308, 1.7976931348623157e308),
104105
}
105106

106107
dtype_nbits = {
@@ -404,3 +405,28 @@ def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str:
404405
# i.e. dtype is bool, int, or float
405406
f_types.append(type_.__name__)
406407
return ", ".join(f_types)
408+
409+
410+
class EqualityMapping(Mapping):
411+
def __init__(self, mapping: Mapping):
412+
keys = list(mapping.keys())
413+
for i, key in enumerate(keys):
414+
if not (key == key): # specifically test __eq__, not __neq__
415+
raise ValueError("Key {key!r} does not have equality with itself")
416+
other_keys = keys[:]
417+
other_keys.pop(i)
418+
for other_key in other_keys:
419+
if key == other_key:
420+
raise ValueError("Key {key!r} has equality with key {other_key!r}")
421+
self._mapping = mapping
422+
423+
def __getitem__(self, key):
424+
for k, v in self._mapping.items():
425+
if key == k:
426+
return v
427+
428+
def __iter__(self):
429+
return iter(self._mapping)
430+
431+
def __len__(self):
432+
return len(self._mapping)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
3+
from ..dtype_helpers import EqualityMapping
4+
5+
6+
def test_raises_on_distinct_eq_key():
7+
with pytest.raises(ValueError):
8+
EqualityMapping({float("nan"): "foo"})
9+
10+
11+
def test_raises_on_indistinct_eq_keys():
12+
class AlwaysEq:
13+
def __init__(self, hash):
14+
self._hash = hash
15+
16+
def __eq__(self, other):
17+
return True
18+
19+
def __hash__(self):
20+
return self._hash
21+
22+
with pytest.raises(ValueError):
23+
EqualityMapping({AlwaysEq(0): "foo", AlwaysEq(1): "bar"})

0 commit comments

Comments
 (0)