Skip to content

Commit 06fc784

Browse files
committed
Construct inplace_op_to_symbol in dtype_helpers
1 parent 8e79ebf commit 06fc784

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@
2323
'func_out_categories',
2424
'op_in_categories',
2525
'op_out_categories',
26+
'op_to_func',
2627
'binary_op_to_symbol',
2728
'unary_op_to_symbol',
28-
'op_to_func',
29+
'inplace_op_to_symbol',
2930
]
3031

3132

@@ -328,6 +329,16 @@ class MinMax(NamedTuple):
328329

329330
op_in_categories = {}
330331
op_out_categories = {}
331-
for op_func, elwise_func in op_to_func.items():
332-
op_in_categories[op_func] = func_in_categories[elwise_func]
333-
op_out_categories[op_func] = func_out_categories[elwise_func]
332+
for op, elwise_func in op_to_func.items():
333+
op_in_categories[op] = func_in_categories[elwise_func]
334+
op_out_categories[op] = func_out_categories[elwise_func]
335+
336+
337+
inplace_op_to_symbol = {}
338+
for op, symbol in binary_op_to_symbol.items():
339+
if op == '__matmul__' or op_out_categories[op] == 'bool':
340+
continue
341+
iop = f'__i{op[2:]}'
342+
inplace_op_to_symbol[iop] = f'{symbol}='
343+
op_in_categories[iop] = op_in_categories[op]
344+
op_out_categories[iop] = op_out_categories[op]

array_api_tests/test_type_promotion.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -168,24 +168,23 @@ def test_operator_returns_array_with_correct_dtype(
168168

169169

170170
def gen_inplace_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]:
171-
for op, symbol in dh.binary_op_to_symbol.items():
172-
if op == '__matmul__' or dh.op_out_categories[op] == 'bool':
171+
for op, symbol in dh.inplace_op_to_symbol.items():
172+
if op == '__imatmul__':
173173
continue
174174
in_category = dh.op_in_categories[op]
175175
valid_in_dtypes = dh.category_to_dtypes[in_category]
176-
iop = f'__i{op[2:]}'
177176
for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items():
178177
if (
179178
in_dtype1 == promoted_dtype
180179
and in_dtype1 in valid_in_dtypes
181180
and in_dtype2 in valid_in_dtypes
182181
):
183182
yield pytest.param(
184-
f'x1 {symbol}= x2',
183+
f'x1 {symbol} x2',
185184
(in_dtype1, in_dtype2),
186185
promoted_dtype,
187-
filters[iop],
188-
id=f'{iop}({in_dtype1}, {in_dtype2}) -> {promoted_dtype}',
186+
filters[op],
187+
id=f'{op}({in_dtype1}, {in_dtype2}) -> {promoted_dtype}',
189188
)
190189

191190

@@ -252,19 +251,18 @@ def test_binary_operator_promotes_python_scalars(
252251

253252

254253
def gen_inplace_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, Callable]]:
255-
for op, symbol in dh.binary_op_to_symbol.items():
256-
if op == '__matmul__' or dh.op_out_categories[op] == 'bool':
254+
for op, symbol in dh.inplace_op_to_symbol.items():
255+
if op == '__imatmul__':
257256
continue
258257
in_category = dh.op_in_categories[op]
259-
iop = f'__i{op[2:]}'
260258
for dtype in dh.category_to_dtypes[in_category]:
261259
for in_stype in dh.dtypes_to_scalars[dtype]:
262260
yield pytest.param(
263-
f'x {symbol}= s',
261+
f'x {symbol} s',
264262
dtype,
265263
in_stype,
266-
filters[iop],
267-
id=f'{iop}({dtype}, {in_stype.__name__}) -> {dtype}',
264+
filters[op],
265+
id=f'{op}({dtype}, {in_stype.__name__}) -> {dtype}',
268266
)
269267

270268

0 commit comments

Comments
 (0)