|
37 | 37 | ]
|
38 | 38 |
|
39 | 39 |
|
| 40 | +class EqualityMapping(Mapping): |
| 41 | + """ |
| 42 | + Mapping that uses equality for indexing |
| 43 | +
|
| 44 | + Typical mappings (e.g. the built-in dict) use hashing for indexing. This |
| 45 | + isn't ideal for the Array API, as no __hash__() method is specified for |
| 46 | + dtype objects - but __eq__() is! |
| 47 | +
|
| 48 | + See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects |
| 49 | + """ |
| 50 | + |
| 51 | + def __init__(self, mapping: Mapping): |
| 52 | + keys = list(mapping.keys()) |
| 53 | + for i, key in enumerate(keys): |
| 54 | + if not (key == key): # specifically checking __eq__, not __neq__ |
| 55 | + raise ValueError("Key {key!r} does not have equality with itself") |
| 56 | + other_keys = keys[:] |
| 57 | + other_keys.pop(i) |
| 58 | + for other_key in other_keys: |
| 59 | + if key == other_key: |
| 60 | + raise ValueError("Key {key!r} has equality with key {other_key!r}") |
| 61 | + self._mapping = mapping |
| 62 | + |
| 63 | + def __getitem__(self, key): |
| 64 | + for k, v in self._mapping.items(): |
| 65 | + if key == k: |
| 66 | + return v |
| 67 | + else: |
| 68 | + raise KeyError(f"{key!r} not found") |
| 69 | + |
| 70 | + def __iter__(self): |
| 71 | + return iter(self._mapping) |
| 72 | + |
| 73 | + def __len__(self): |
| 74 | + return len(self._mapping) |
| 75 | + |
| 76 | + def __repr__(self): |
| 77 | + return f"EqualityMapping({self._mapping!r})" |
| 78 | + |
| 79 | + |
40 | 80 | _uint_names = ("uint8", "uint16", "uint32", "uint64")
|
41 | 81 | _int_names = ("int8", "int16", "int32", "int64")
|
42 | 82 | _float_names = ("float32", "float64")
|
|
52 | 92 | bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
|
53 | 93 |
|
54 | 94 |
|
55 |
| -dtype_to_name = {getattr(xp, name): name for name in _dtype_names} |
| 95 | +dtype_to_name = EqualityMapping({getattr(xp, name): name for name in _dtype_names}) |
56 | 96 |
|
57 | 97 |
|
58 |
| -dtype_to_scalars = { |
59 |
| - xp.bool: [bool], |
60 |
| - **{d: [int] for d in all_int_dtypes}, |
61 |
| - **{d: [int, float] for d in float_dtypes}, |
62 |
| -} |
| 98 | +dtype_to_scalars = EqualityMapping( |
| 99 | + { |
| 100 | + xp.bool: [bool], |
| 101 | + **{d: [int] for d in all_int_dtypes}, |
| 102 | + **{d: [int, float] for d in float_dtypes}, |
| 103 | + } |
| 104 | +) |
63 | 105 |
|
64 | 106 |
|
65 | 107 | def is_int_dtype(dtype):
|
@@ -91,31 +133,37 @@ class MinMax(NamedTuple):
|
91 | 133 | max: Union[int, float]
|
92 | 134 |
|
93 | 135 |
|
94 |
| -dtype_ranges = { |
95 |
| - xp.int8: MinMax(-128, +127), |
96 |
| - xp.int16: MinMax(-32_768, +32_767), |
97 |
| - xp.int32: MinMax(-2_147_483_648, +2_147_483_647), |
98 |
| - xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807), |
99 |
| - xp.uint8: MinMax(0, +255), |
100 |
| - xp.uint16: MinMax(0, +65_535), |
101 |
| - xp.uint32: MinMax(0, +4_294_967_295), |
102 |
| - xp.uint64: MinMax(0, +18_446_744_073_709_551_615), |
103 |
| - xp.float32: MinMax(-3.4028234663852886e38, 3.4028234663852886e38), |
104 |
| - xp.float64: MinMax(-1.7976931348623157e308, 1.7976931348623157e308), |
105 |
| -} |
106 |
| - |
107 |
| -dtype_nbits = { |
108 |
| - **{d: 8 for d in [xp.int8, xp.uint8]}, |
109 |
| - **{d: 16 for d in [xp.int16, xp.uint16]}, |
110 |
| - **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, |
111 |
| - **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, |
112 |
| -} |
113 |
| - |
114 |
| - |
115 |
| -dtype_signed = { |
116 |
| - **{d: True for d in int_dtypes}, |
117 |
| - **{d: False for d in uint_dtypes}, |
118 |
| -} |
| 136 | +dtype_ranges = EqualityMapping( |
| 137 | + { |
| 138 | + xp.int8: MinMax(-128, +127), |
| 139 | + xp.int16: MinMax(-32_768, +32_767), |
| 140 | + xp.int32: MinMax(-2_147_483_648, +2_147_483_647), |
| 141 | + xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807), |
| 142 | + xp.uint8: MinMax(0, +255), |
| 143 | + xp.uint16: MinMax(0, +65_535), |
| 144 | + xp.uint32: MinMax(0, +4_294_967_295), |
| 145 | + xp.uint64: MinMax(0, +18_446_744_073_709_551_615), |
| 146 | + xp.float32: MinMax(-3.4028234663852886e38, 3.4028234663852886e38), |
| 147 | + xp.float64: MinMax(-1.7976931348623157e308, 1.7976931348623157e308), |
| 148 | + } |
| 149 | +) |
| 150 | + |
| 151 | +dtype_nbits = EqualityMapping( |
| 152 | + { |
| 153 | + **{d: 8 for d in [xp.int8, xp.uint8]}, |
| 154 | + **{d: 16 for d in [xp.int16, xp.uint16]}, |
| 155 | + **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, |
| 156 | + **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, |
| 157 | + } |
| 158 | +) |
| 159 | + |
| 160 | + |
| 161 | +dtype_signed = EqualityMapping( |
| 162 | + { |
| 163 | + **{d: True for d in int_dtypes}, |
| 164 | + **{d: False for d in uint_dtypes}, |
| 165 | + } |
| 166 | +) |
119 | 167 |
|
120 | 168 |
|
121 | 169 | if isinstance(xp.asarray, _UndefinedStub):
|
@@ -179,11 +227,13 @@ class MinMax(NamedTuple):
|
179 | 227 | (xp.float32, xp.float64): xp.float64,
|
180 | 228 | (xp.float64, xp.float64): xp.float64,
|
181 | 229 | }
|
182 |
| -promotion_table = { |
183 |
| - (xp.bool, xp.bool): xp.bool, |
184 |
| - **_numeric_promotions, |
185 |
| - **{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()}, |
186 |
| -} |
| 230 | +promotion_table = EqualityMapping( |
| 231 | + { |
| 232 | + (xp.bool, xp.bool): xp.bool, |
| 233 | + **_numeric_promotions, |
| 234 | + **{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()}, |
| 235 | + } |
| 236 | +) |
187 | 237 |
|
188 | 238 |
|
189 | 239 | def result_type(*dtypes: DataType):
|
@@ -405,42 +455,3 @@ def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str:
|
405 | 455 | # i.e. dtype is bool, int, or float
|
406 | 456 | f_types.append(type_.__name__)
|
407 | 457 | return ", ".join(f_types)
|
408 |
| - |
409 |
| - |
410 |
| -class EqualityMapping(Mapping): |
411 |
| - """ |
412 |
| - Mapping that uses equality for indexing |
413 |
| -
|
414 |
| - Typical mappings (e.g. the built-in dict) use hashing for indexing. This |
415 |
| - isn't ideal for the Array API, as no __hash__() method is specified for |
416 |
| - dtype objects - but __eq__() is! |
417 |
| -
|
418 |
| - See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects |
419 |
| - """ |
420 |
| - def __init__(self, mapping: Mapping): |
421 |
| - keys = list(mapping.keys()) |
422 |
| - for i, key in enumerate(keys): |
423 |
| - if not (key == key): # specifically checking __eq__, not __neq__ |
424 |
| - raise ValueError("Key {key!r} does not have equality with itself") |
425 |
| - other_keys = keys[:] |
426 |
| - other_keys.pop(i) |
427 |
| - for other_key in other_keys: |
428 |
| - if key == other_key: |
429 |
| - raise ValueError("Key {key!r} has equality with key {other_key!r}") |
430 |
| - self._mapping = mapping |
431 |
| - |
432 |
| - def __getitem__(self, key): |
433 |
| - for k, v in self._mapping.items(): |
434 |
| - if key == k: |
435 |
| - return v |
436 |
| - else: |
437 |
| - raise KeyError(f"{key!r} not found") |
438 |
| - |
439 |
| - def __iter__(self): |
440 |
| - return iter(self._mapping) |
441 |
| - |
442 |
| - def __len__(self): |
443 |
| - return len(self._mapping) |
444 |
| - |
445 |
| - def __repr__(self): |
446 |
| - return f"EqualityMapping({self._mapping!r})" |
|
0 commit comments