|
12 | 12 | SupportsBufferProtocol,
|
13 | 13 | )
|
14 | 14 | from ._dtypes import _DType, _all_dtypes
|
| 15 | +from ._flags import get_array_api_strict_flags |
15 | 16 |
|
16 | 17 | import numpy as np
|
17 | 18 |
|
@@ -174,19 +175,38 @@ def eye(
|
174 | 175 |
|
175 | 176 | See its docstring for more information.
|
176 | 177 | """
|
177 |
| - from ._array_object import Array, CPU_DEVICE |
| 178 | + from ._array_object import Array |
178 | 179 |
|
179 | 180 | _check_valid_dtype(dtype)
|
180 |
| - if device not in [CPU_DEVICE, None]: |
181 |
| - raise ValueError(f"Unsupported device {device!r}") |
| 181 | + _check_device(device) |
| 182 | + |
182 | 183 | if dtype is not None:
|
183 | 184 | dtype = dtype._np_dtype
|
184 | 185 | return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
|
185 | 186 |
|
186 | 187 |
|
187 |
| -def from_dlpack(x: object, /) -> Array: |
| 188 | +_default = object() |
| 189 | + |
| 190 | +def from_dlpack( |
| 191 | + x: object, |
| 192 | + /, |
| 193 | + *, |
| 194 | + device: Optional[Device] = _default, |
| 195 | + copy: Optional[bool] = _default, |
| 196 | +) -> Array: |
188 | 197 | from ._array_object import Array
|
189 | 198 |
|
| 199 | + if get_array_api_strict_flags()['api_version'] < '2023.12': |
| 200 | + if device is not _default: |
| 201 | + raise ValueError("The device argument to from_dlpack requires at least version 2023.12 of the array API") |
| 202 | + if copy is not _default: |
| 203 | + raise ValueError("The copy argument to from_dlpack requires at least version 2023.12 of the array API") |
| 204 | + |
| 205 | + if device is not _default: |
| 206 | + _check_device(device) |
| 207 | + if copy not in [_default, None]: |
| 208 | + raise NotImplementedError("The copy argument to from_dlpack is not yet implemented") |
| 209 | + |
190 | 210 | return Array._new(np.from_dlpack(x))
|
191 | 211 |
|
192 | 212 |
|
|
0 commit comments