|
3 | 3 | import pytest
|
4 | 4 |
|
5 | 5 | from numpy.testing import assert_raises
|
6 |
| -import array_api_strict as xp |
7 | 6 | import numpy as np
|
8 | 7 |
|
| 8 | +from .._creation_functions import asarray |
| 9 | +from .._data_type_functions import astype, can_cast, isdtype |
| 10 | +from .._dtypes import ( |
| 11 | + bool, int8, int16, uint8, float64, |
| 12 | +) |
| 13 | +from .._flags import set_array_api_strict_flags |
| 14 | + |
| 15 | + |
9 | 16 | @pytest.mark.parametrize(
|
10 | 17 | "from_, to, expected",
|
11 | 18 | [
|
12 |
| - (xp.int8, xp.int16, True), |
13 |
| - (xp.int16, xp.int8, False), |
14 |
| - (xp.bool, xp.int8, False), |
15 |
| - (xp.asarray(0, dtype=xp.uint8), xp.int8, False), |
| 19 | + (int8, int16, True), |
| 20 | + (int16, int8, False), |
| 21 | + (bool, int8, False), |
| 22 | + (asarray(0, dtype=uint8), int8, False), |
16 | 23 | ],
|
17 | 24 | )
|
18 | 25 | def test_can_cast(from_, to, expected):
|
19 | 26 | """
|
20 | 27 | can_cast() returns correct result
|
21 | 28 | """
|
22 |
| - assert xp.can_cast(from_, to) == expected |
| 29 | + assert can_cast(from_, to) == expected |
23 | 30 |
|
24 | 31 | def test_isdtype_strictness():
|
25 |
| - assert_raises(TypeError, lambda: xp.isdtype(xp.float64, 64)) |
26 |
| - assert_raises(ValueError, lambda: xp.isdtype(xp.float64, 'f8')) |
| 32 | + assert_raises(TypeError, lambda: isdtype(float64, 64)) |
| 33 | + assert_raises(ValueError, lambda: isdtype(float64, 'f8')) |
27 | 34 |
|
28 |
| - assert_raises(TypeError, lambda: xp.isdtype(xp.float64, (('integral',),))) |
| 35 | + assert_raises(TypeError, lambda: isdtype(float64, (('integral',),))) |
29 | 36 | with assert_raises(TypeError), warnings.catch_warnings(record=True) as w:
|
30 | 37 | warnings.simplefilter("always")
|
31 |
| - xp.isdtype(xp.float64, np.object_) |
| 38 | + isdtype(float64, np.object_) |
32 | 39 | assert len(w) == 1
|
33 | 40 | assert issubclass(w[-1].category, UserWarning)
|
34 | 41 |
|
35 |
| - assert_raises(TypeError, lambda: xp.isdtype(xp.float64, None)) |
| 42 | + assert_raises(TypeError, lambda: isdtype(float64, None)) |
36 | 43 | with assert_raises(TypeError), warnings.catch_warnings(record=True) as w:
|
37 | 44 | warnings.simplefilter("always")
|
38 |
| - xp.isdtype(xp.float64, np.float64) |
| 45 | + isdtype(float64, np.float64) |
39 | 46 | assert len(w) == 1
|
40 | 47 | assert issubclass(w[-1].category, UserWarning)
|
| 48 | + |
| 49 | + |
| 50 | +@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) |
| 51 | +def astype_device(api_version): |
| 52 | + if api_version != '2022.12': |
| 53 | + with pytest.warns(UserWarning): |
| 54 | + set_array_api_strict_flags(api_version=api_version) |
| 55 | + else: |
| 56 | + set_array_api_strict_flags(api_version=api_version) |
| 57 | + |
| 58 | + a = asarray([1, 2, 3], dtype=int8) |
| 59 | + # Never an error |
| 60 | + astype(a, int16) |
| 61 | + |
| 62 | + # Always an error |
| 63 | + astype(a, int16, device="cpu") |
| 64 | + |
| 65 | + if api_version >= '2023.12': |
| 66 | + astype(a, int8, device=None) |
| 67 | + astype(a, int8, device=a.device) |
| 68 | + else: |
| 69 | + pytest.raises(TypeError, lambda: astype(a, int8, device=None)) |
| 70 | + pytest.raises(TypeError, lambda: astype(a, int8, device=a.device)) |
0 commit comments