Skip to content

Commit 6b43194

Browse files
committed
Add device flag to astype in 2023.12
Also clean up imports in test_data_type_functions.py
1 parent 47894ff commit 6b43194

File tree

2 files changed

+58
-16
lines changed

2 files changed

+58
-16
lines changed

array_api_strict/_data_type_functions.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from ._array_object import Array
3+
from ._array_object import Array, CPU_DEVICE
44
from ._dtypes import (
55
_DType,
66
_all_dtypes,
@@ -13,19 +13,31 @@
1313
_numeric_dtypes,
1414
_result_type,
1515
)
16+
from ._flags import get_array_api_strict_flags
1617

1718
from dataclasses import dataclass
1819
from typing import TYPE_CHECKING
1920

2021
if TYPE_CHECKING:
21-
from typing import List, Tuple, Union
22-
from ._typing import Dtype
22+
from typing import List, Tuple, Union, Optional
23+
from ._typing import Dtype, Device
2324

2425
import numpy as np
2526

27+
# Use to emulate the asarray(device) argument not existing in 2022.12
28+
_default = object()
2629

2730
# Note: astype is a function, not an array method as in NumPy.
28-
def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array:
31+
def astype(
32+
x: Array, dtype: Dtype, /, *, copy: bool = True, device: Optional[Device] = _default
33+
) -> Array:
34+
if device is not _default:
35+
if get_array_api_strict_flags()['api_version'] >= '2023.12':
36+
if device not in [CPU_DEVICE, None]:
37+
raise ValueError(f"Unsupported device {device!r}")
38+
else:
39+
raise TypeError("The device argument to astype requires the 2023.12 version of the array API")
40+
2941
if not copy and dtype == x.dtype:
3042
return x
3143
return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy))

array_api_strict/tests/test_data_type_functions.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,38 +3,68 @@
33
import pytest
44

55
from numpy.testing import assert_raises
6-
import array_api_strict as xp
76
import numpy as np
87

8+
from .._creation_functions import asarray
9+
from .._data_type_functions import astype, can_cast, isdtype
10+
from .._dtypes import (
11+
bool, int8, int16, uint8, float64,
12+
)
13+
from .._flags import set_array_api_strict_flags
14+
15+
916
@pytest.mark.parametrize(
1017
"from_, to, expected",
1118
[
12-
(xp.int8, xp.int16, True),
13-
(xp.int16, xp.int8, False),
14-
(xp.bool, xp.int8, False),
15-
(xp.asarray(0, dtype=xp.uint8), xp.int8, False),
19+
(int8, int16, True),
20+
(int16, int8, False),
21+
(bool, int8, False),
22+
(asarray(0, dtype=uint8), int8, False),
1623
],
1724
)
1825
def test_can_cast(from_, to, expected):
1926
"""
2027
can_cast() returns correct result
2128
"""
22-
assert xp.can_cast(from_, to) == expected
29+
assert can_cast(from_, to) == expected
2330

2431
def test_isdtype_strictness():
25-
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, 64))
26-
assert_raises(ValueError, lambda: xp.isdtype(xp.float64, 'f8'))
32+
assert_raises(TypeError, lambda: isdtype(float64, 64))
33+
assert_raises(ValueError, lambda: isdtype(float64, 'f8'))
2734

28-
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, (('integral',),)))
35+
assert_raises(TypeError, lambda: isdtype(float64, (('integral',),)))
2936
with assert_raises(TypeError), warnings.catch_warnings(record=True) as w:
3037
warnings.simplefilter("always")
31-
xp.isdtype(xp.float64, np.object_)
38+
isdtype(float64, np.object_)
3239
assert len(w) == 1
3340
assert issubclass(w[-1].category, UserWarning)
3441

35-
assert_raises(TypeError, lambda: xp.isdtype(xp.float64, None))
42+
assert_raises(TypeError, lambda: isdtype(float64, None))
3643
with assert_raises(TypeError), warnings.catch_warnings(record=True) as w:
3744
warnings.simplefilter("always")
38-
xp.isdtype(xp.float64, np.float64)
45+
isdtype(float64, np.float64)
3946
assert len(w) == 1
4047
assert issubclass(w[-1].category, UserWarning)
48+
49+
50+
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
51+
def astype_device(api_version):
52+
if api_version != '2022.12':
53+
with pytest.warns(UserWarning):
54+
set_array_api_strict_flags(api_version=api_version)
55+
else:
56+
set_array_api_strict_flags(api_version=api_version)
57+
58+
a = asarray([1, 2, 3], dtype=int8)
59+
# Never an error
60+
astype(a, int16)
61+
62+
# Always an error
63+
astype(a, int16, device="cpu")
64+
65+
if api_version >= '2023.12':
66+
astype(a, int8, device=None)
67+
astype(a, int8, device=a.device)
68+
else:
69+
pytest.raises(TypeError, lambda: astype(a, int8, device=None))
70+
pytest.raises(TypeError, lambda: astype(a, int8, device=a.device))

0 commit comments

Comments
 (0)