Skip to content

Commit 50e7775

Browse files
committed
Use EqualityMapping for relevant dtype helpers
1 parent 22f8b75 commit 50e7775

File tree

1 file changed

+86
-75
lines changed

1 file changed

+86
-75
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 86 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,46 @@
3737
]
3838

3939

40+
class EqualityMapping(Mapping):
41+
"""
42+
Mapping that uses equality for indexing
43+
44+
Typical mappings (e.g. the built-in dict) use hashing for indexing. This
45+
isn't ideal for the Array API, as no __hash__() method is specified for
46+
dtype objects - but __eq__() is!
47+
48+
See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects
49+
"""
50+
51+
def __init__(self, mapping: Mapping):
52+
keys = list(mapping.keys())
53+
for i, key in enumerate(keys):
54+
if not (key == key): # specifically checking __eq__, not __neq__
55+
raise ValueError("Key {key!r} does not have equality with itself")
56+
other_keys = keys[:]
57+
other_keys.pop(i)
58+
for other_key in other_keys:
59+
if key == other_key:
60+
raise ValueError("Key {key!r} has equality with key {other_key!r}")
61+
self._mapping = mapping
62+
63+
def __getitem__(self, key):
64+
for k, v in self._mapping.items():
65+
if key == k:
66+
return v
67+
else:
68+
raise KeyError(f"{key!r} not found")
69+
70+
def __iter__(self):
71+
return iter(self._mapping)
72+
73+
def __len__(self):
74+
return len(self._mapping)
75+
76+
def __repr__(self):
77+
return f"EqualityMapping({self._mapping!r})"
78+
79+
4080
_uint_names = ("uint8", "uint16", "uint32", "uint64")
4181
_int_names = ("int8", "int16", "int32", "int64")
4282
_float_names = ("float32", "float64")
@@ -52,14 +92,16 @@
5292
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
5393

5494

55-
dtype_to_name = {getattr(xp, name): name for name in _dtype_names}
95+
dtype_to_name = EqualityMapping({getattr(xp, name): name for name in _dtype_names})
5696

5797

58-
dtype_to_scalars = {
59-
xp.bool: [bool],
60-
**{d: [int] for d in all_int_dtypes},
61-
**{d: [int, float] for d in float_dtypes},
62-
}
98+
dtype_to_scalars = EqualityMapping(
99+
{
100+
xp.bool: [bool],
101+
**{d: [int] for d in all_int_dtypes},
102+
**{d: [int, float] for d in float_dtypes},
103+
}
104+
)
63105

64106

65107
def is_int_dtype(dtype):
@@ -91,31 +133,37 @@ class MinMax(NamedTuple):
91133
max: Union[int, float]
92134

93135

94-
dtype_ranges = {
95-
xp.int8: MinMax(-128, +127),
96-
xp.int16: MinMax(-32_768, +32_767),
97-
xp.int32: MinMax(-2_147_483_648, +2_147_483_647),
98-
xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807),
99-
xp.uint8: MinMax(0, +255),
100-
xp.uint16: MinMax(0, +65_535),
101-
xp.uint32: MinMax(0, +4_294_967_295),
102-
xp.uint64: MinMax(0, +18_446_744_073_709_551_615),
103-
xp.float32: MinMax(-3.4028234663852886e38, 3.4028234663852886e38),
104-
xp.float64: MinMax(-1.7976931348623157e308, 1.7976931348623157e308),
105-
}
106-
107-
dtype_nbits = {
108-
**{d: 8 for d in [xp.int8, xp.uint8]},
109-
**{d: 16 for d in [xp.int16, xp.uint16]},
110-
**{d: 32 for d in [xp.int32, xp.uint32, xp.float32]},
111-
**{d: 64 for d in [xp.int64, xp.uint64, xp.float64]},
112-
}
113-
114-
115-
dtype_signed = {
116-
**{d: True for d in int_dtypes},
117-
**{d: False for d in uint_dtypes},
118-
}
136+
dtype_ranges = EqualityMapping(
137+
{
138+
xp.int8: MinMax(-128, +127),
139+
xp.int16: MinMax(-32_768, +32_767),
140+
xp.int32: MinMax(-2_147_483_648, +2_147_483_647),
141+
xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807),
142+
xp.uint8: MinMax(0, +255),
143+
xp.uint16: MinMax(0, +65_535),
144+
xp.uint32: MinMax(0, +4_294_967_295),
145+
xp.uint64: MinMax(0, +18_446_744_073_709_551_615),
146+
xp.float32: MinMax(-3.4028234663852886e38, 3.4028234663852886e38),
147+
xp.float64: MinMax(-1.7976931348623157e308, 1.7976931348623157e308),
148+
}
149+
)
150+
151+
dtype_nbits = EqualityMapping(
152+
{
153+
**{d: 8 for d in [xp.int8, xp.uint8]},
154+
**{d: 16 for d in [xp.int16, xp.uint16]},
155+
**{d: 32 for d in [xp.int32, xp.uint32, xp.float32]},
156+
**{d: 64 for d in [xp.int64, xp.uint64, xp.float64]},
157+
}
158+
)
159+
160+
161+
dtype_signed = EqualityMapping(
162+
{
163+
**{d: True for d in int_dtypes},
164+
**{d: False for d in uint_dtypes},
165+
}
166+
)
119167

120168

121169
if isinstance(xp.asarray, _UndefinedStub):
@@ -179,11 +227,13 @@ class MinMax(NamedTuple):
179227
(xp.float32, xp.float64): xp.float64,
180228
(xp.float64, xp.float64): xp.float64,
181229
}
182-
promotion_table = {
183-
(xp.bool, xp.bool): xp.bool,
184-
**_numeric_promotions,
185-
**{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()},
186-
}
230+
promotion_table = EqualityMapping(
231+
{
232+
(xp.bool, xp.bool): xp.bool,
233+
**_numeric_promotions,
234+
**{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()},
235+
}
236+
)
187237

188238

189239
def result_type(*dtypes: DataType):
@@ -405,42 +455,3 @@ def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str:
405455
# i.e. dtype is bool, int, or float
406456
f_types.append(type_.__name__)
407457
return ", ".join(f_types)
408-
409-
410-
class EqualityMapping(Mapping):
411-
"""
412-
Mapping that uses equality for indexing
413-
414-
Typical mappings (e.g. the built-in dict) use hashing for indexing. This
415-
isn't ideal for the Array API, as no __hash__() method is specified for
416-
dtype objects - but __eq__() is!
417-
418-
See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects
419-
"""
420-
def __init__(self, mapping: Mapping):
421-
keys = list(mapping.keys())
422-
for i, key in enumerate(keys):
423-
if not (key == key): # specifically checking __eq__, not __neq__
424-
raise ValueError("Key {key!r} does not have equality with itself")
425-
other_keys = keys[:]
426-
other_keys.pop(i)
427-
for other_key in other_keys:
428-
if key == other_key:
429-
raise ValueError("Key {key!r} has equality with key {other_key!r}")
430-
self._mapping = mapping
431-
432-
def __getitem__(self, key):
433-
for k, v in self._mapping.items():
434-
if key == k:
435-
return v
436-
else:
437-
raise KeyError(f"{key!r} not found")
438-
439-
def __iter__(self):
440-
return iter(self._mapping)
441-
442-
def __len__(self):
443-
return len(self._mapping)
444-
445-
def __repr__(self):
446-
return f"EqualityMapping({self._mapping!r})"

0 commit comments

Comments
 (0)