@@ -268,42 +268,13 @@ def lift_constant_tensor_pass(ep):
268
268
buffers = list (graph_signature .buffers )
269
269
270
270
fake_mode = list (ep .graph .nodes )[0 ].meta ["val" ].fake_mode
271
- insert_before_node = None
271
+ first_user_input = None
272
272
lifted_constants = []
273
273
for node in ep .graph .nodes :
274
274
if node .op == "placeholder" and node .name in graph_signature .user_inputs :
275
- insert_before_node = node # first user input
275
+ first_user_input = node
276
276
break
277
277
278
- if insert_before_node is None :
279
- # we have no user inputs, find the node after the last buffer
280
- # (that we will insert the lifted constants before).
281
- # this is a bit hacky, but I am not certain of what the contract is for
282
- # node ordering. is the first non-placeholder node guranteed to be the
283
- # first node after input paramters? what if there is no op, and it is
284
- # just placeholders? Easier to just find the last buffer, and insert after.
285
-
286
- # also error if we have no buffers and no user inputs... if that is an issue, fix it later?
287
- last_buffer = None
288
- for node in ep .graph .nodes :
289
- node_buffer_fqn = graph_signature .inputs_to_buffers .get (node .name , None )
290
- # not sure if both cases are needed, if is it possible to encounter a
291
- # buffer that is not a user input?
292
- if (
293
- node_buffer_fqn is not None
294
- and node_buffer_fqn in graph_signature .buffers
295
- ):
296
- last_buffer = node
297
- continue
298
- if node .op == "placeholder" and node .name in graph_signature .buffers :
299
- last_buffer = node
300
- continue
301
- # we have our last buffer, grab the node after it, to insert the lifted constants before.
302
- insert_before_node = last_buffer .next
303
-
304
- if insert_before_node is None :
305
- raise ValueError ("No user inputs and no buffers found. Cannot lift constants." )
306
-
307
278
for node in ep .graph .nodes :
308
279
if node .op == "get_attr" :
309
280
constant_tensor = getattr (ep .graph_module , node .target )
@@ -312,7 +283,7 @@ def lift_constant_tensor_pass(ep):
312
283
313
284
constant_tensor_fqn = f"_lifted_tensor_constant{ len (buffers )} "
314
285
315
- with ep .graph .inserting_before (insert_before_node ):
286
+ with ep .graph .inserting_before (first_user_input ):
316
287
# Insert the constant node before the first user input
317
288
const_placeholder_node = ep .graph .placeholder (constant_tensor_fqn )
318
289
for k , v in node .meta .items ():
@@ -345,9 +316,8 @@ def lift_constant_tensor_pass(ep):
345
316
new_input_specs .extend (lifted_constants )
346
317
lifted_constants .clear ()
347
318
new_input_specs .append (s )
348
- # Add remaining lifted constants if no user inputs exist.
349
319
if len (lifted_constants ) > 0 :
350
- new_input_specs . extend ( lifted_constants )
320
+ new_input_specs = lifted_constants + new_input_specs
351
321
ep .graph_signature .input_specs = new_input_specs
352
322
ep .graph_module .recompile ()
353
323
return ep
0 commit comments