File tree Expand file tree Collapse file tree 1 file changed +24
-11
lines changed Expand file tree Collapse file tree 1 file changed +24
-11
lines changed Original file line number Diff line number Diff line change 2
2
import torch
3
3
from torch import * # noqa: F401, F403
4
4
5
- for n in dir (torch ):
5
+ from .._internal import _get_all_public_members
6
+
7
+
8
+ def filter_ (name ):
6
9
if (
7
- n .startswith ("_" )
8
- or n .endswith ("_" )
9
- or "cuda" in n
10
- or "cpu" in n
11
- or "backward" in n
10
+ name .startswith ("_" )
11
+ or name .endswith ("_" )
12
+ or "cuda" in name
13
+ or "cpu" in name
14
+ or "backward" in name
12
15
):
13
- continue
14
- exec (n + " = torch." + n )
16
+ return False
17
+ return True
18
+
19
+
20
+ _torch_all = _get_all_public_members (torch , filter_ = filter_ )
15
21
22
+ for _name in _torch_all :
23
+ globals ()[_name ] = getattr (torch , _name )
16
24
17
- from ..common ._helpers import (
25
+
26
+ from ..common ._helpers import ( # noqa: E402
18
27
array_namespace ,
19
28
device ,
20
29
get_namespace ,
24
33
)
25
34
26
35
# These imports may overwrite names from the import * above.
27
- from ._aliases import (
36
+ from ._aliases import ( # noqa: E402
28
37
add ,
29
38
all ,
30
39
any ,
92
101
zeros ,
93
102
)
94
103
95
- __all__ = [
104
+ __all__ = []
105
+
106
+ __all__ += _torch_all
107
+
108
+ __all__ += [
96
109
"is_array_api_obj" ,
97
110
"array_namespace" ,
98
111
"get_namespace" ,
You can’t perform that action at this time.
0 commit comments