-
Notifications
You must be signed in to change notification settings - Fork 35
TYP: Type annotations, part 4 #313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 15 commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
362c48a
Type annotations, part 4
crusaderky ad375dc
Fix CopyMode
crusaderky 49f9ba7
revert
crusaderky 4371506
Merge branch 'main' into typ_v4
crusaderky c724a52
Revert `_all_ignore`
crusaderky 14f70af
code review
crusaderky 0a571bc
code review
crusaderky 0172300
JustInt mypy ignores
crusaderky 8711041
Merge branch 'main' into typ_v4
crusaderky 014e20f
lint
crusaderky 7c5408c
Merge branch 'main' into typ_v4
crusaderky 924fc3d
fix merge
crusaderky 5d98aa8
Merge branch 'main' into typ_v4
crusaderky 8eb647f
lint
crusaderky a06d51f
Merge branch 'main' into typ_v4
crusaderky 247ee6d
Reverts and tweaks
crusaderky 85fce08
Fix test_all
crusaderky 983296f
Revert batmobile
crusaderky d81b3aa
Merge branch 'main' into typ_v4
crusaderky 2954efd
Merge branch 'main' into typ_v4
crusaderky c244872
Merge branch 'main' into typ_v4
crusaderky File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -12,50 +12,46 @@ | |||||
import math | ||||||
import sys | ||||||
import warnings | ||||||
from collections.abc import Collection, Hashable | ||||||
from collections.abc import Hashable | ||||||
from functools import lru_cache | ||||||
from types import NoneType | ||||||
from typing import ( | ||||||
TYPE_CHECKING, | ||||||
Any, | ||||||
Final, | ||||||
Literal, | ||||||
SupportsIndex, | ||||||
TypeAlias, | ||||||
TypeGuard, | ||||||
TypeVar, | ||||||
cast, | ||||||
overload, | ||||||
) | ||||||
|
||||||
from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
|
||||||
import cupy as cp | ||||||
import dask.array as da | ||||||
import jax | ||||||
import ndonnx as ndx | ||||||
import numpy as np | ||||||
import numpy.typing as npt | ||||||
import sparse # pyright: ignore[reportMissingTypeStubs] | ||||||
import sparse | ||||||
import torch | ||||||
|
||||||
# TODO: import from typing (requires Python >=3.13) | ||||||
from typing_extensions import TypeIs, TypeVar | ||||||
|
||||||
_SizeT = TypeVar("_SizeT", bound = int | None) | ||||||
from typing_extensions import TypeIs | ||||||
|
||||||
_ZeroGradientArray: TypeAlias = npt.NDArray[np.void] | ||||||
_CupyArray: TypeAlias = Any # cupy has no py.typed | ||||||
|
||||||
_ArrayApiObj: TypeAlias = ( | ||||||
npt.NDArray[Any] | ||||||
| cp.ndarray | ||||||
| da.Array | ||||||
| jax.Array | ||||||
| ndx.Array | ||||||
| sparse.SparseArray | ||||||
| torch.Tensor | ||||||
| SupportsArrayNamespace[Any] | ||||||
| _CupyArray | ||||||
| SupportsArrayNamespace | ||||||
) | ||||||
|
||||||
_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) | ||||||
|
@@ -95,7 +91,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: | |||||
return dtype == jax.float0 | ||||||
|
||||||
|
||||||
def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: | ||||||
def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]: | ||||||
""" | ||||||
Return True if `x` is a NumPy array. | ||||||
|
||||||
|
@@ -266,7 +262,7 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: | |||||
return _issubclass_fast(cls, "sparse", "SparseArray") | ||||||
|
||||||
|
||||||
def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] | ||||||
def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]: | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
""" | ||||||
Return True if `x` is an array API compatible array object. | ||||||
|
||||||
|
@@ -581,7 +577,7 @@ def your_function(x, y): | |||||
|
||||||
namespaces.add(cupy_namespace) | ||||||
else: | ||||||
import cupy as cp # pyright: ignore[reportMissingTypeStubs] | ||||||
import cupy as cp | ||||||
|
||||||
namespaces.add(cp) | ||||||
elif is_torch_array(x): | ||||||
|
@@ -618,14 +614,14 @@ def your_function(x, y): | |||||
if hasattr(jax.numpy, "__array_api_version__"): | ||||||
jnp = jax.numpy | ||||||
else: | ||||||
import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] | ||||||
import jax.experimental.array_api as jnp # type: ignore[no-redef] | ||||||
namespaces.add(jnp) | ||||||
elif is_pydata_sparse_array(x): | ||||||
if use_compat is True: | ||||||
_check_api_version(api_version) | ||||||
raise ValueError("`sparse` does not have an array-api-compat wrapper") | ||||||
else: | ||||||
import sparse # pyright: ignore[reportMissingTypeStubs] | ||||||
import sparse | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
# `sparse` is already an array namespace. We do not have a wrapper | ||||||
# submodule for it. | ||||||
namespaces.add(sparse) | ||||||
|
@@ -634,9 +630,9 @@ def your_function(x, y): | |||||
raise ValueError( | ||||||
"The given array does not have an array-api-compat wrapper" | ||||||
) | ||||||
x = cast("SupportsArrayNamespace[Any]", x) | ||||||
x = cast(SupportsArrayNamespace, x) | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
namespaces.add(x.__array_namespace__(api_version=api_version)) | ||||||
elif isinstance(x, (bool, int, float, complex, type(None))): | ||||||
elif isinstance(x, int | float | complex | NoneType): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
(I'll spare you the pseudo-philosophical rant this time) |
||||||
continue | ||||||
else: | ||||||
# TODO: Support Python scalars? | ||||||
|
@@ -732,7 +728,7 @@ def device(x: _ArrayApiObj, /) -> Device: | |||||
return "cpu" | ||||||
elif is_dask_array(x): | ||||||
# Peek at the metadata of the Dask array to determine type | ||||||
if is_numpy_array(x._meta): # pyright: ignore | ||||||
if is_numpy_array(x._meta): | ||||||
# Must be on CPU since backed by numpy | ||||||
return "cpu" | ||||||
return _DASK_DEVICE | ||||||
|
@@ -761,7 +757,7 @@ def device(x: _ArrayApiObj, /) -> Device: | |||||
return "cpu" | ||||||
# Return the device of the constituent array | ||||||
return device(inner) # pyright: ignore | ||||||
return x.device # pyright: ignore | ||||||
return x.device # type: ignore # pyright: ignore | ||||||
|
||||||
|
||||||
# Prevent shadowing, used below | ||||||
|
@@ -770,11 +766,11 @@ def device(x: _ArrayApiObj, /) -> Device: | |||||
|
||||||
# Based on cupy.array_api.Array.to_device | ||||||
def _cupy_to_device( | ||||||
x: _CupyArray, | ||||||
x: cp.ndarray, | ||||||
device: Device, | ||||||
/, | ||||||
stream: int | Any | None = None, | ||||||
) -> _CupyArray: | ||||||
) -> cp.ndarray: | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
import cupy as cp | ||||||
|
||||||
if device == "cpu": | ||||||
|
@@ -803,7 +799,7 @@ def _torch_to_device( | |||||
x: torch.Tensor, | ||||||
device: torch.device | str | int, | ||||||
/, | ||||||
stream: None = None, | ||||||
stream: int | Any | None = None, | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
) -> torch.Tensor: | ||||||
if stream is not None: | ||||||
raise NotImplementedError | ||||||
|
@@ -869,7 +865,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - | |||||
# cupy does not yet have to_device | ||||||
return _cupy_to_device(x, device, stream=stream) | ||||||
elif is_torch_array(x): | ||||||
return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType] | ||||||
return _torch_to_device(x, device, stream=stream) | ||||||
elif is_dask_array(x): | ||||||
if stream is not None: | ||||||
raise ValueError("The stream argument to to_device() is not supported") | ||||||
|
@@ -894,12 +890,12 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - | |||||
|
||||||
|
||||||
@overload | ||||||
def size(x: HasShape[Collection[SupportsIndex]]) -> int: ... | ||||||
def size(x: HasShape[int]) -> int: ... | ||||||
@overload | ||||||
def size(x: HasShape[Collection[None]]) -> None: ... | ||||||
def size(x: HasShape[int | None]) -> int | None: ... | ||||||
@overload | ||||||
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ... | ||||||
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: | ||||||
def size(x: HasShape[float]) -> int | None: ... # Dask special case | ||||||
def size(x: HasShape[float | None]) -> int | None: | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
""" | ||||||
Return the total number of elements of x. | ||||||
|
||||||
|
@@ -914,9 +910,9 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: | |||||
# Lazy API compliant arrays, such as ndonnx, can contain None in their shape | ||||||
if None in x.shape: | ||||||
return None | ||||||
out = math.prod(cast("Collection[SupportsIndex]", x.shape)) | ||||||
out = math.prod(cast(tuple[float, ...], x.shape)) | ||||||
# dask.array.Array.shape can contain NaN | ||||||
return None if math.isnan(out) else out | ||||||
return None if math.isnan(out) else cast(int, out) | ||||||
|
||||||
|
||||||
@lru_cache(100) | ||||||
|
@@ -932,7 +928,7 @@ def _is_writeable_cls(cls: type) -> bool | None: | |||||
return None | ||||||
|
||||||
|
||||||
def is_writeable_array(x: object) -> bool: | ||||||
def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]: | ||||||
""" | ||||||
Return False if ``x.__setitem__`` is expected to raise; True otherwise. | ||||||
Return False if `x` is not an array API compatible object. | ||||||
|
@@ -970,7 +966,7 @@ def _is_lazy_cls(cls: type) -> bool | None: | |||||
return None | ||||||
|
||||||
|
||||||
def is_lazy_array(x: object) -> bool: | ||||||
def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]: | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
"""Return True if x is potentially a future or it may be otherwise impossible or | ||||||
expensive to eagerly read its contents, regardless of their size, e.g. by | ||||||
calling ``bool(x)`` or ``float(x)``. | ||||||
|
@@ -1007,7 +1003,7 @@ def is_lazy_array(x: object) -> bool: | |||||
# on __bool__ (dask is one such example, which however is special-cased above). | ||||||
|
||||||
# Select a single point of the array | ||||||
s = size(cast("HasShape[Collection[SupportsIndex | None]]", x)) | ||||||
s = size(cast(HasShape, x)) | ||||||
if s is None: | ||||||
return True | ||||||
xp = array_namespace(x) | ||||||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.