-
Notifications
You must be signed in to change notification settings - Fork 11
MAINT: Array API 2024.12 typing nits #156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
bf67bb8
36a422f
b67865f
123b11a
d43d8c8
35783d2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
import warnings | ||
from collections.abc import Sequence | ||
from types import ModuleType | ||
from typing import cast | ||
from typing import TYPE_CHECKING, cast | ||
|
||
from ._at import at | ||
from ._utils import _compat, _helpers | ||
|
@@ -375,8 +375,8 @@ def expand_dims( | |
|
||
|
||
def isclose( | ||
a: Array, | ||
b: Array, | ||
a: Array | complex, | ||
b: Array | complex, | ||
*, | ||
rtol: float = 1e-05, | ||
atol: float = 1e-08, | ||
|
@@ -385,6 +385,9 @@ def isclose( | |
) -> Array: # numpydoc ignore=PR01,RT01 | ||
"""See docstring in array_api_extra._delegation.""" | ||
a, b = asarrays(a, b, xp=xp) | ||
if TYPE_CHECKING: # Hack around pyright bug # pragma: no cover | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the pyright bug? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. given def f(x: Array | complex) -> Array: ...
x: Array | complex
x = f(x) when you reach the end of the snippet, mypy correctly understands that x can only be Array, whereas pyright still believes it could be complex. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah, so is it about reusing variable names? Is there an upstream issue we can link to? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. given the upstream answer, what would you like to do here? Use different variable names? Keep the current workarounds but remove the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've replaced |
||
assert _compat.is_array_api_obj(a) | ||
assert _compat.is_array_api_obj(b) | ||
|
||
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating")) | ||
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating")) | ||
|
@@ -419,7 +422,13 @@ def isclose( | |
return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol) | ||
|
||
|
||
def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: | ||
def kron( | ||
a: Array | complex, | ||
b: Array | complex, | ||
/, | ||
*, | ||
xp: ModuleType | None = None, | ||
) -> Array: | ||
""" | ||
Kronecker product of two arrays. | ||
|
||
|
@@ -495,9 +504,14 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: | |
if xp is None: | ||
xp = array_namespace(a, b) | ||
a, b = asarrays(a, b, xp=xp) | ||
if TYPE_CHECKING: # Hack around pyright bug # pragma: no cover | ||
assert _compat.is_array_api_obj(a) | ||
assert _compat.is_array_api_obj(b) | ||
|
||
singletons = (1,) * (b.ndim - a.ndim) | ||
a = xp.broadcast_to(a, singletons + a.shape) | ||
if TYPE_CHECKING: # Hack around pyright bug # pragma: no cover | ||
assert _compat.is_array_api_obj(a) | ||
|
||
nd_b, nd_a = b.ndim, a.ndim | ||
nd_max = max(nd_b, nd_a) | ||
|
@@ -614,8 +628,8 @@ def pad( | |
|
||
|
||
def setdiff1d( | ||
x1: Array, | ||
x2: Array, | ||
x1: Array | complex, | ||
x2: Array | complex, | ||
/, | ||
*, | ||
assume_unique: bool = False, | ||
|
@@ -628,7 +642,7 @@ def setdiff1d( | |
|
||
Parameters | ||
---------- | ||
x1 : array | ||
x1 : array | int | float | complex | bool | ||
Input array. | ||
x2 : array | ||
Input comparison array. | ||
|
@@ -665,6 +679,10 @@ def setdiff1d( | |
else: | ||
x1 = xp.unique_values(x1) | ||
x2 = xp.unique_values(x2) | ||
|
||
if TYPE_CHECKING: # Hack around pyright bug # pragma: no cover | ||
assert _compat.is_array_api_obj(x1) | ||
|
||
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)] | ||
|
||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.