16
16
from torch .fx import Node
17
17
18
18
_ScalarType = Union [int , bool , float ]
19
- _Argument = Union [torch . fx . Node , int , bool , float , str ]
19
+ _Argument = Union [Node , int , bool , float , str ]
20
20
21
21
22
22
class VkGraphBuilder :
@@ -29,7 +29,7 @@ def __init__(self, program: ExportedProgram) -> None:
29
29
self .output_ids = []
30
30
self .const_tensors = []
31
31
32
- # Mapping from torch.fx. Node to VkValue id
32
+ # Mapping from Node to VkValue id
33
33
self .node_to_value_ids = {}
34
34
35
35
@staticmethod
@@ -39,18 +39,18 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
39
39
else :
40
40
raise AssertionError (f"Invalid dtype for vulkan_preprocess ({ torch_dtype } )" )
41
41
42
- def is_constant (self , node : torch . fx . Node ):
42
+ def is_constant (self , node : Node ):
43
43
return (
44
44
node .name in self .program .graph_signature .inputs_to_lifted_tensor_constants
45
45
)
46
46
47
- def is_get_attr_node (self , node : torch . fx . Node ) -> bool :
47
+ def is_get_attr_node (self , node : Node ) -> bool :
48
48
"""
49
49
Returns true if the given node is a get attr node for a tensor of the model
50
50
"""
51
- return isinstance (node , torch . fx . Node ) and node .op == "get_attr"
51
+ return isinstance (node , Node ) and node .op == "get_attr"
52
52
53
- def is_param_node (self , node : torch . fx . Node ) -> bool :
53
+ def is_param_node (self , node : Node ) -> bool :
54
54
"""
55
55
Check if the given node is a parameter within the exported program
56
56
"""
@@ -61,7 +61,7 @@ def is_param_node(self, node: torch.fx.Node) -> bool:
61
61
or self .is_constant (node )
62
62
)
63
63
64
- def get_constant (self , node : torch . fx . Node ) -> Optional [torch .Tensor ]:
64
+ def get_constant (self , node : Node ) -> Optional [torch .Tensor ]:
65
65
"""
66
66
Returns the constant associated with the given node in the exported program.
67
67
Returns None if the node is not a constant within the exported program
@@ -79,7 +79,7 @@ def get_constant(self, node: torch.fx.Node) -> Optional[torch.Tensor]:
79
79
80
80
return None
81
81
82
- def get_param_tensor (self , node : torch . fx . Node ) -> torch .Tensor :
82
+ def get_param_tensor (self , node : Node ) -> torch .Tensor :
83
83
tensor = None
84
84
if node is None :
85
85
raise RuntimeError ("node is None" )
@@ -168,7 +168,7 @@ def create_string_value(self, string: str) -> int:
168
168
return new_id
169
169
170
170
def get_or_create_value_for (self , arg : _Argument ):
171
- if isinstance (arg , torch . fx . Node ):
171
+ if isinstance (arg , Node ):
172
172
# If the value has already been created, return the existing id
173
173
if arg in self .node_to_value_ids :
174
174
return self .node_to_value_ids [arg ]
0 commit comments