Skip to content

Commit 184de15

Browse files
committed
Fix generate_stubs.py to correctly find annotations for functions with underscores
1 parent 8926ce7 commit 184de15

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

array_api_tests/function_stubs/linalg.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from __future__ import annotations
1212

1313
from ._types import Literal, Optional, Tuple, Union, array
14+
from .constants import inf
1415
from collections.abc import Sequence
1516

1617
def cholesky(x: array, /, *, upper: bool = False) -> array:
@@ -46,7 +47,7 @@ def inv(x: array, /) -> array:
4647
def matmul(x1: array, x2: array, /) -> array:
4748
pass
4849

49-
def matrix_norm(x, /, *, axis=(-2, -1), keepdims=False, ord='fro'):
50+
def matrix_norm(x: array, /, *, axis: Tuple[int, int] = (-2, -1), keepdims: bool = False, ord: Optional[Union[int, float, Literal[inf, -inf, 'fro', 'nuc']]] = 'fro') -> array:
5051
pass
5152

5253
def matrix_power(x: array, n: int, /) -> array:
@@ -55,7 +56,7 @@ def matrix_power(x: array, n: int, /) -> array:
5556
def matrix_rank(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
5657
pass
5758

58-
def matrix_transpose(x, /):
59+
def matrix_transpose(x: array, /) -> array:
5960
pass
6061

6162
def outer(x1: array, x2: array, /) -> array:
@@ -88,7 +89,7 @@ def trace(x: array, /, *, offset: int = 0) -> array:
8889
def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
8990
pass
9091

91-
def vector_norm(x, /, *, axis=None, keepdims=False, ord=2):
92+
def vector_norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float, Literal[inf, -inf]]] = 2) -> array:
9293
pass
9394

9495
__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eig', 'eigh', 'eigvals', 'eigvalsh', 'einsum', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']

array_api_tests/function_stubs/linear_algebra_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def einsum():
1919
def matmul(x1: array, x2: array, /) -> array:
2020
pass
2121

22-
def matrix_transpose(x, /):
22+
def matrix_transpose(x: array, /) -> array:
2323
pass
2424

2525
def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> array:

generate_stubs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -690,15 +690,15 @@ def parse_annotations(spec_text, all_annotations, verbose=False):
690690
for line in spec_text.splitlines():
691691
m = HEADER_RE.match(line)
692692
if m:
693-
name = m.group(1)
693+
name = m.group(1).replace('-', '_')
694694
continue
695695
m = ALIAS_RE.match(line)
696696
if m:
697-
alias_name = m.group(1)
697+
alias_name = m.group(1).replace('-', '_')
698698
if alias_name not in all_annotations:
699699
print(f"Warning: No annotations for aliased function {name}")
700700
else:
701-
annotations[name] = all_annotations[m.group(1)]
701+
annotations[name] = all_annotations[alias_name]
702702
continue
703703
if line == '#### Parameters':
704704
in_block = True

0 commit comments

Comments
 (0)