8
8
9
9
from copy import deepcopy
10
10
11
- import executorch .backends .vulkan .custom_ops_lib # noqa
12
-
13
- import torch
14
-
15
11
from executorch .backends .vulkan .op_registry import handles_own_prepacking
16
12
from executorch .backends .vulkan .utils import is_param_node
17
13
@@ -31,27 +27,27 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
31
27
argument into the operator implementation.
32
28
"""
33
29
34
- def prepack_not_required (node : torch .fx .Node ) -> bool :
30
+ for node in program .graph_module .graph .nodes :
31
+ # Prepacking is only needed for constant tensors. Only nodes corresponding to
32
+ # constant tensors will proceed beyond this point.
35
33
if not is_param_node (program , node ):
36
- return True
34
+ continue
37
35
38
- # Annotate that this node is going to represented as a tensorref in the Vulkan
39
- # compute graph. This will be useful for later graph passes.
36
+ # Mark that this node is going to be represented as a TensorRef type in the
37
+ # Vulkan compute graph. This annotation is used in later graph passes.
40
38
node .meta ["vkdg_tensorref" ] = True
41
39
40
+ # Get the list of node users that do not handle their own prepacking
41
+ nodes_to_replace_input = []
42
42
for user in node .users :
43
- if user .op == "call_function" and handles_own_prepacking (
44
- # pyre-ignore
45
- user .target
46
- ):
47
- return True
43
+ if user .op == "call_function" and not handles_own_prepacking (user .target ):
44
+ nodes_to_replace_input .append (user )
48
45
49
- return False
50
-
51
- for node in program .graph_module .graph .nodes :
52
- if prepack_not_required (node ):
46
+ if len (nodes_to_replace_input ) == 0 :
53
47
continue
54
48
49
+ replace_all_uses = len (nodes_to_replace_input ) == len (node .users )
50
+
55
51
with program .graph_module .graph .inserting_after (node ):
56
52
prepack_node = program .graph_module .graph .create_node (
57
53
"call_function" ,
@@ -74,9 +70,14 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
74
70
# Set the mem_obj_id to -1 to indicate that this node requires a dedicated
75
71
# memory object.
76
72
prepack_node .meta ["spec" ].mem_obj_id = - 1
77
- node .replace_all_uses_with (
78
- prepack_node , lambda x , y = prepack_node : (x != y and x .op != "output" )
79
- )
73
+ if replace_all_uses :
74
+ node .replace_all_uses_with (
75
+ prepack_node ,
76
+ lambda x , y = prepack_node : (x != y and x .op != "output" ),
77
+ )
78
+ else :
79
+ for user_node in nodes_to_replace_input :
80
+ user_node .replace_input_with (node , prepack_node )
80
81
81
82
program .graph .eliminate_dead_code ()
82
83
return program
0 commit comments