Skip to content

Commit 8ca03d2

Browse files
committed
Add missing names to torch.linalg
1 parent c55eb43 commit 8ca03d2

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

array_api_compat/cupy/__init__.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,7 @@
66
# These imports may overwrite names from the import * above.
77
from ._aliases import *
88

9-
# Don't know why, but we have to do an absolute import to import linalg. If we
10-
# instead do
11-
#
12-
# from . import linalg
13-
#
14-
# It doesn't overwrite cupy.linalg from above. The import is generated
15-
# dynamically so that the library can be vendored.
9+
# See the comment in the numpy __init__.py
1610
__import__(__package__ + '.linalg')
1711

1812
from .linalg import matrix_transpose, vecdot

array_api_compat/torch/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
# These imports may overwrite names from the import * above.
1515
from ._aliases import *
1616

17+
# See the comment in the numpy __init__.py
18+
__import__(__package__ + '.linalg')
19+
1720
from ..common._helpers import *
1821

1922
__array_api_version__ = '2021.12'

array_api_compat/torch/linalg.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,15 @@
1-
raise ImportError("The array api compat torch.linalg module extension is not yet implemented")
1+
from torch.linalg import *
2+
3+
# torch.linalg doesn't define __all__
4+
# from torch.linalg import __all__ as linalg_all
5+
from torch import linalg as _linalg
6+
linalg_all = [i for i in dir(_linalg) if not i.startswith('_')]
7+
8+
# These are implemented in torch but aren't in the linalg namespace
9+
from torch import outer, trace
10+
from ._aliases import matrix_transpose, tensordot
11+
12+
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot']
13+
14+
del linalg_all
15+
del _linalg

0 commit comments

Comments
 (0)