|
| 1 | +import numpy as np |
| 2 | +try: |
| 3 | + import cupy as cp |
| 4 | +except ImportError: |
| 5 | + cp = None |
| 6 | +import pytest |
| 7 | + |
| 8 | +from cuda.core.experimental import Device |
| 9 | +from cuda.core.experimental.utils import StridedMemoryView, viewable |
| 10 | + |
| 11 | + |
| 12 | +@pytest.mark.parametrize( |
| 13 | + "in_arr,", ( |
| 14 | + np.empty(3, dtype=np.int32), |
| 15 | + np.empty((6, 6), dtype=np.float64)[::2, ::2], |
| 16 | + np.empty((3, 4), order='F'), |
| 17 | + ) |
| 18 | +) |
| 19 | +def test_viewable_cpu(in_arr): |
| 20 | + |
| 21 | + @viewable((0,)) |
| 22 | + def my_func(arr): |
| 23 | + view = arr.view(-1) |
| 24 | + assert view.ptr == in_arr.ctypes.data |
| 25 | + assert view.shape == in_arr.shape |
| 26 | + if in_arr.flags.c_contiguous: |
| 27 | + assert view.strides is None |
| 28 | + else: |
| 29 | + assert view.strides == tuple(s // in_arr.dtype.itemsize for s in in_arr.strides) |
| 30 | + assert view.dtype == in_arr.dtype |
| 31 | + assert view.device_id == 0 |
| 32 | + assert view.device_accessible == False |
| 33 | + assert view.exporting_obj is in_arr |
| 34 | + |
| 35 | + my_func(in_arr) |
| 36 | + |
| 37 | + |
| 38 | +if cp is not None: |
| 39 | + |
| 40 | + @pytest.mark.parametrize( |
| 41 | + "in_arr,stream", ( |
| 42 | + (cp.empty(3, dtype=cp.complex64), None), |
| 43 | + (cp.empty((6, 6), dtype=cp.float64)[::2, ::2], True), |
| 44 | + (cp.empty((3, 4), order='F'), True), |
| 45 | + ) |
| 46 | + ) |
| 47 | + def test_viewable_gpu(in_arr, stream): |
| 48 | + # TODO: use the device fixture? |
| 49 | + dev = Device() |
| 50 | + dev.set_current() |
| 51 | + s = dev.create_stream() if stream else None |
| 52 | + |
| 53 | + @viewable((0,)) |
| 54 | + def my_func(arr): |
| 55 | + view = arr.view(s.handle if s else -1) |
| 56 | + assert view.ptr == in_arr.data.ptr |
| 57 | + assert view.shape == in_arr.shape |
| 58 | + strides_in_counts = tuple(s // in_arr.dtype.itemsize for s in in_arr.strides) |
| 59 | + if in_arr.flags.c_contiguous: |
| 60 | + assert view.strides in (None, strides_in_counts) |
| 61 | + else: |
| 62 | + assert view.strides == strides_in_counts |
| 63 | + assert view.dtype == in_arr.dtype |
| 64 | + assert view.device_id == dev.device_id |
| 65 | + assert view.device_accessible == True |
| 66 | + assert view.exporting_obj is in_arr |
| 67 | + |
| 68 | + my_func(in_arr) |
0 commit comments