Skip to content

Commit c447516

Browse files
committed
Fix final ruff errors in array_api_compat/torch/__init__.py
1 parent fb447e6 commit c447516

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

array_api_compat/torch/__init__.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,28 @@
22
import torch
33
from torch import * # noqa: F401, F403
44

5-
for n in dir(torch):
5+
from .._internal import _get_all_public_members
6+
7+
8+
def filter_(name):
69
if (
7-
n.startswith("_")
8-
or n.endswith("_")
9-
or "cuda" in n
10-
or "cpu" in n
11-
or "backward" in n
10+
name.startswith("_")
11+
or name.endswith("_")
12+
or "cuda" in name
13+
or "cpu" in name
14+
or "backward" in name
1215
):
13-
continue
14-
exec(n + " = torch." + n)
16+
return False
17+
return True
18+
19+
20+
_torch_all = _get_all_public_members(torch, filter_=filter_)
1521

22+
for _name in _torch_all:
23+
globals()[_name] = getattr(torch, _name)
1624

17-
from ..common._helpers import (
25+
26+
from ..common._helpers import ( # noqa: E402
1827
array_namespace,
1928
device,
2029
get_namespace,
@@ -24,7 +33,7 @@
2433
)
2534

2635
# These imports may overwrite names from the import * above.
27-
from ._aliases import (
36+
from ._aliases import ( # noqa: E402
2837
add,
2938
all,
3039
any,
@@ -92,7 +101,11 @@
92101
zeros,
93102
)
94103

95-
__all__ = [
104+
__all__ = []
105+
106+
__all__ += _torch_all
107+
108+
__all__ += [
96109
"is_array_api_obj",
97110
"array_namespace",
98111
"get_namespace",

0 commit comments

Comments
 (0)