Skip to content

Commit 647a5f0

Browse files
committed
Add tests for from_dlpack and __dlpack__ 2023.12 behavior
1 parent dc4684b commit 647a5f0

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

array_api_strict/tests/test_array_object.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
uint64,
2424
bool as bool_,
2525
)
26+
from .._flags import set_array_api_strict_flags
27+
2628
import array_api_strict
2729

2830
def test_validate_index():
@@ -420,3 +422,33 @@ def test_array_namespace():
420422

421423
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
422424
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))
425+
426+
427+
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
428+
def dlpack_2023_12(api_version):
429+
if api_version != '2022.12':
430+
with pytest.warns(UserWarning):
431+
set_array_api_strict_flags(api_version=api_version)
432+
else:
433+
set_array_api_strict_flags(api_version=api_version)
434+
435+
a = asarray([1, 2, 3], dtype=int8)
436+
# Never an error
437+
a.__dlpack__()
438+
439+
440+
exception = NotImplementedError if api_version >= '2023.12' else ValueError
441+
pytest.raises(exception, lambda:
442+
a.__dlpack__(dl_device=CPU_DEVICE))
443+
pytest.raises(exception, lambda:
444+
a.__dlpack__(dl_device=None))
445+
pytest.raises(exception, lambda:
446+
a.__dlpack__(max_version=(1, 0)))
447+
pytest.raises(exception, lambda:
448+
a.__dlpack__(max_version=None))
449+
pytest.raises(exception, lambda:
450+
a.__dlpack__(copy=False))
451+
pytest.raises(exception, lambda:
452+
a.__dlpack__(copy=True))
453+
pytest.raises(exception, lambda:
454+
a.__dlpack__(copy=None))

array_api_strict/tests/test_creation_functions.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
from numpy.testing import assert_raises
44
import numpy as np
55

6+
import pytest
7+
68
from .. import all
79
from .._creation_functions import (
810
asarray,
911
arange,
1012
empty,
1113
empty_like,
1214
eye,
15+
from_dlpack,
1316
full,
1417
full_like,
1518
linspace,
@@ -21,7 +24,7 @@
2124
)
2225
from .._dtypes import float32, float64
2326
from .._array_object import Array, CPU_DEVICE
24-
27+
from .._flags import set_array_api_strict_flags
2528

2629
def test_asarray_errors():
2730
# Test various protections against incorrect usage
@@ -188,3 +191,24 @@ def test_meshgrid_dtype_errors():
188191
meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float32))
189192

190193
assert_raises(ValueError, lambda: meshgrid(asarray([1.], dtype=float32), asarray([1.], dtype=float64)))
194+
195+
196+
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
197+
def from_dlpack_2023_12(api_version):
198+
if api_version != '2022.12':
199+
with pytest.warns(UserWarning):
200+
set_array_api_strict_flags(api_version=api_version)
201+
else:
202+
set_array_api_strict_flags(api_version=api_version)
203+
204+
a = asarray([1., 2., 3.], dtype=float64)
205+
# Never an error
206+
capsule = a.__dlpack__()
207+
from_dlpack(capsule)
208+
209+
exception = NotImplementedError if api_version >= '2023.12' else ValueError
210+
pytest.raises(exception, lambda: from_dlpack(capsule, device=CPU_DEVICE))
211+
pytest.raises(exception, lambda: from_dlpack(capsule, device=None))
212+
pytest.raises(exception, lambda: from_dlpack(capsule, copy=False))
213+
pytest.raises(exception, lambda: from_dlpack(capsule, copy=True))
214+
pytest.raises(exception, lambda: from_dlpack(capsule, copy=None))

0 commit comments

Comments
 (0)