@@ -258,13 +258,42 @@ def lift_constant_tensor_pass(ep):
258
258
buffers = list (graph_signature .buffers )
259
259
260
260
fake_mode = list (ep .graph .nodes )[0 ].meta ["val" ].fake_mode
261
- first_user_input = None
261
+ insert_before_node = None
262
262
lifted_constants = []
263
263
for node in ep .graph .nodes :
264
264
if node .op == "placeholder" and node .name in graph_signature .user_inputs :
265
- first_user_input = node
265
+ insert_before_node = node # first user input
266
266
break
267
267
268
+ if insert_before_node is None :
269
+ # we have no user inputs, find the node after the last buffer
270
+ # (that we will insert the lifted constants before).
271
+ # this is a bit hacky, but I am not certain of what the contract is for
272
+ # node ordering. is the first non-placeholder node guranteed to be the
273
+ # first node after input paramters? what if there is no op, and it is
274
+ # just placeholders? Easier to just find the last buffer, and insert after.
275
+
276
+ # also error if we have no buffers and no user inputs... if that is an issue, fix it later?
277
+ last_buffer = None
278
+ for node in ep .graph .nodes :
279
+ node_buffer_fqn = graph_signature .inputs_to_buffers .get (node .name , None )
280
+ # not sure if both cases are needed, if is it possible to encounter a
281
+ # buffer that is not a user input?
282
+ if (
283
+ node_buffer_fqn is not None
284
+ and node_buffer_fqn in graph_signature .buffers
285
+ ):
286
+ last_buffer = node
287
+ continue
288
+ if node .op == "placeholder" and node .name in graph_signature .buffers :
289
+ last_buffer = node
290
+ continue
291
+ # we have our last buffer, grab the node after it, to insert the lifted constants before.
292
+ insert_before_node = last_buffer .next
293
+
294
+ if insert_before_node is None :
295
+ raise ValueError ("No user inputs and no buffers found. Cannot lift constants." )
296
+
268
297
for node in ep .graph .nodes :
269
298
if node .op == "get_attr" :
270
299
constant_tensor = getattr (ep .graph_module , node .target )
@@ -273,7 +302,7 @@ def lift_constant_tensor_pass(ep):
273
302
274
303
constant_tensor_fqn = f"_lifted_tensor_constant{ len (buffers )} "
275
304
276
- with ep .graph .inserting_before (first_user_input ):
305
+ with ep .graph .inserting_before (insert_before_node ):
277
306
# Insert the constant node before the first user input
278
307
const_placeholder_node = ep .graph .placeholder (constant_tensor_fqn )
279
308
for k , v in node .meta .items ():
@@ -306,6 +335,9 @@ def lift_constant_tensor_pass(ep):
306
335
new_input_specs .extend (lifted_constants )
307
336
lifted_constants .clear ()
308
337
new_input_specs .append (s )
338
+ # Add remaining lifted constants if no user inputs exist.
339
+ if len (lifted_constants ) > 0 :
340
+ new_input_specs .extend (lifted_constants )
309
341
ep .graph_signature .input_specs = new_input_specs
310
342
ep .graph_module .recompile ()
311
343
return ep
0 commit comments