@@ -39,6 +39,30 @@ def set_memory_metadata(
39
39
utils .set_node_spec_attr (node , "vk_memory_layout" , layout )
40
40
41
41
42
def insert_transition_node(
    graph_module: torch.fx.GraphModule,
    node: torch.fx.Node,
    arg: torch.fx.Node,
    storage: VkStorageType,
    layout: VkMemoryLayout,
) -> None:
    """
    Transition ``arg`` to the requested storage type and memory layout before it
    is consumed by ``node``.

    A clone op is inserted immediately before ``node``; the clone inherits
    ``arg``'s fake value and a copy of its tensor spec, is tagged with the
    target storage/layout, and then replaces ``arg`` as an input of ``node``
    (other users of ``arg`` are left untouched).
    """
    graph = graph_module.graph
    with graph.inserting_before(node):
        transition = graph.create_node(
            "call_function",
            exir_ops.edge.aten.clone.default,
            (arg,),
        )
        # The clone yields the same fake value as the source tensor.
        transition.meta["val"] = arg.meta["val"]
        # Deep-copy the spec so edits do not leak back to `arg`; the cloned
        # tensor is produced at runtime, so it is never a constant.
        transition.meta["spec"] = deepcopy(arg.meta["spec"])
        transition.meta["spec"].const = False
        set_memory_metadata(transition, storage, layout)
        # Rewire only `node` to consume the transition node.
        arg.replace_all_uses_with(transition, lambda user, target=node: user == target)
66
class TagMemoryMetaPass (ExportPass ):
43
67
"""
44
68
There are a variety of ways that tensors can be represented in Vulkan. The two main
@@ -174,14 +198,33 @@ def propose_node_layout(
174
198
else :
175
199
return next (iter (valid_layouts ))
176
200
201
+ def should_annotate (self , node ) -> bool :
202
+ if not isinstance (node , torch .fx .Node ):
203
+ return False
204
+
205
+ if not isinstance (node .meta ["val" ], FakeTensor ):
206
+ return False
207
+
208
+ # Storage type and memory layout for tensorref will be determined at runtime
209
+ # so there's no use in setting those attributes ahead of time.
210
+ if node .meta .get ("vkdg_tensorref" , False ):
211
+ return False
212
+
213
+ return True
214
+
215
+ def should_delay_annotation (self , node : torch .fx .Node ) -> bool :
216
+ # For prepack nodes, delay setting the storage type and memory layout as long as
217
+ # possible. This is to minimize the number of transitions, since it can be
218
+ # difficult to predict what storage type and memory layout should be used at the
219
+ # time the prepack node is observed.
220
+ return node .target == exir_ops .edge .et_vk .prepack .default
221
+
222
+ # noqa
177
223
def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
178
224
sorted_nodes : NodeList = topo_sort (list (graph_module .graph .nodes ))
179
225
180
226
for node in sorted_nodes :
181
- if not isinstance (node .meta ["val" ], FakeTensor ):
182
- continue
183
-
184
- if node .target == exir_ops .edge .et_vk .prepack .default :
227
+ if not self .should_annotate (node ) or self .should_delay_annotation (node ):
185
228
continue
186
229
187
230
storage = self .propose_node_storage (node )
@@ -191,11 +234,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
191
234
192
235
inserting_transitions_for_node = False
193
236
for i , arg in enumerate (node .args ):
194
- if not isinstance (arg , torch .fx .Node ):
195
- continue
196
- if not isinstance (arg .meta ["val" ], FakeTensor ):
237
+ if not self .should_annotate (arg ):
197
238
continue
198
239
240
+ assert isinstance (arg , torch .fx .Node )
241
+
199
242
arg_storage = utils .get_node_storage_type (arg )
200
243
arg_layout = utils .get_node_memory_layout (arg )
201
244
@@ -215,22 +258,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
215
258
f"[Vulkan Delegate] Inserting transition(s) for { node .format_node ()} :"
216
259
)
217
260
261
+ insert_transition_node (graph_module , node , arg , storage , layout )
262
+
218
263
logger .info (
219
264
f" args { i } ({ arg } ): ({ arg_storage } , { arg_layout } ) -> ({ storage } , { layout } )"
220
265
)
221
266
222
- # Insert a clone node to copy the original tensor to a tensor with the
223
- # desired storage type and memory layout.
224
- with graph_module .graph .inserting_before (node ):
225
- clone_node = graph_module .graph .create_node (
226
- "call_function" ,
227
- exir_ops .edge .aten .clone .default ,
228
- (arg ,),
229
- )
230
- clone_node .meta ["val" ] = arg .meta ["val" ]
231
- clone_node .meta ["spec" ] = deepcopy (arg .meta ["spec" ])
232
- clone_node .meta ["spec" ].const = False
233
- set_memory_metadata (clone_node , storage , layout )
234
- arg .replace_all_uses_with (clone_node , lambda x , y = node : x == y )
235
-
236
267
return PassResult (graph_module , True )
0 commit comments