Skip to content

Commit 16a3156

Browse files
jorgep31415facebook-github-bot
authored andcommitted
Shorten torch.fx.Node to Node (#2403)
Summary: bypass-github-export-checks Pull Request resolved: #2403 It's too many characters to type out, and there was already a mix of switching between the two. Doing this in a separate change to simplify review. ghstack-source-id: 218520050 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D54859246 fbshipit-source-id: 4a5b57fe5c7ca760df42c9cef37d99f3794904e7
1 parent e807a75 commit 16a3156

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from torch.fx import Node
1717

1818
_ScalarType = Union[int, bool, float]
19-
_Argument = Union[torch.fx.Node, int, bool, float, str]
19+
_Argument = Union[Node, int, bool, float, str]
2020

2121

2222
class VkGraphBuilder:
@@ -29,7 +29,7 @@ def __init__(self, program: ExportedProgram) -> None:
2929
self.output_ids = []
3030
self.const_tensors = []
3131

32-
# Mapping from torch.fx.Node to VkValue id
32+
# Mapping from Node to VkValue id
3333
self.node_to_value_ids = {}
3434

3535
@staticmethod
@@ -39,18 +39,18 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
3939
else:
4040
raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")
4141

42-
def is_constant(self, node: torch.fx.Node):
42+
def is_constant(self, node: Node):
4343
return (
4444
node.name in self.program.graph_signature.inputs_to_lifted_tensor_constants
4545
)
4646

47-
def is_get_attr_node(self, node: torch.fx.Node) -> bool:
47+
def is_get_attr_node(self, node: Node) -> bool:
4848
"""
4949
Returns true if the given node is a get attr node for a tensor of the model
5050
"""
51-
return isinstance(node, torch.fx.Node) and node.op == "get_attr"
51+
return isinstance(node, Node) and node.op == "get_attr"
5252

53-
def is_param_node(self, node: torch.fx.Node) -> bool:
53+
def is_param_node(self, node: Node) -> bool:
5454
"""
5555
Check if the given node is a parameter within the exported program
5656
"""
@@ -61,7 +61,7 @@ def is_param_node(self, node: torch.fx.Node) -> bool:
6161
or self.is_constant(node)
6262
)
6363

64-
def get_constant(self, node: torch.fx.Node) -> Optional[torch.Tensor]:
64+
def get_constant(self, node: Node) -> Optional[torch.Tensor]:
6565
"""
6666
Returns the constant associated with the given node in the exported program.
6767
Returns None if the node is not a constant within the exported program
@@ -79,7 +79,7 @@ def get_constant(self, node: torch.fx.Node) -> Optional[torch.Tensor]:
7979

8080
return None
8181

82-
def get_param_tensor(self, node: torch.fx.Node) -> torch.Tensor:
82+
def get_param_tensor(self, node: Node) -> torch.Tensor:
8383
tensor = None
8484
if node is None:
8585
raise RuntimeError("node is None")
@@ -168,7 +168,7 @@ def create_string_value(self, string: str) -> int:
168168
return new_id
169169

170170
def get_or_create_value_for(self, arg: _Argument):
171-
if isinstance(arg, torch.fx.Node):
171+
if isinstance(arg, Node):
172172
# If the value has already been created, return the existing id
173173
if arg in self.node_to_value_ids:
174174
return self.node_to_value_ids[arg]

0 commit comments

Comments
 (0)