Skip to content

Commit bb07bb0

Browse files
WIP need to clean up functions.
1 parent 7bc06d1 commit bb07bb0

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

exir/memory_planning.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from executorch.exir.tensor import TensorSpec
2525

2626
from torch import fx
27-
from torch.export.exported_program import ExportGraphSignature
27+
from torch.export.exported_program import ExportGraphSignature, InputKind
2828
from torch.fx import Node
2929
from torch.utils._pytree import tree_flatten
3030

@@ -236,7 +236,10 @@ def verify_graph_input_output(self) -> None:
236236
graph_output_allocated = allocated
237237
has_dynamic_unbound_output |= has_dynamic_unbound_tensor
238238

239-
if "placeholder" in check_list:
239+
# only check whether inputs are allocated when there are user inputs:
240+
user_inputs_exist = len(list(filter(lambda input: input.kind == InputKind.USER_INPUT, self.graph_signature.input_specs))) > 0
241+
242+
if "placeholder" in check_list and user_inputs_exist:
240243
assert graph_input_allocated is not None, "graph_input_allocated not set"
241244
if not has_dynamic_unbound_input:
242245
assert (

exir/program/_program.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,42 @@ def lift_constant_tensor_pass(ep):
258258
buffers = list(graph_signature.buffers)
259259

260260
fake_mode = list(ep.graph.nodes)[0].meta["val"].fake_mode
261-
first_user_input = None
261+
insert_before_node = None
262262
lifted_constants = []
263263
for node in ep.graph.nodes:
264264
if node.op == "placeholder" and node.name in graph_signature.user_inputs:
265-
first_user_input = node
265+
insert_before_node = node # first user input
266266
break
267267

268+
if insert_before_node is None:
269+
# we have no user inputs, find the node after the last buffer
270+
# (that we will insert the lifted constants before).
271+
# this is a bit hacky, but I am not certain what the contract is for
272+
# node ordering. Is the first non-placeholder node guaranteed to be the
273+
# first node after input parameters? What if there is no op, and it is
274+
# just placeholders? Easier to just find the last buffer, and insert after.
275+
276+
# also raise an error if we have no buffers and no user inputs; if that becomes an issue, fix it later.
277+
last_buffer = None
278+
for node in ep.graph.nodes:
279+
node_buffer_fqn = graph_signature.inputs_to_buffers.get(node.name, None)
280+
# not sure if both cases are needed; is it possible to encounter a
281+
# buffer that is not a user input?
282+
if (
283+
node_buffer_fqn is not None
284+
and node_buffer_fqn in graph_signature.buffers
285+
):
286+
last_buffer = node
287+
continue
288+
if node.op == "placeholder" and node.name in graph_signature.buffers:
289+
last_buffer = node
290+
continue
291+
# we have our last buffer, grab the node after it, to insert the lifted constants before.
292+
insert_before_node = last_buffer.next
293+
294+
if insert_before_node is None:
295+
raise ValueError("No user inputs and no buffers found. Cannot lift constants.")
296+
268297
for node in ep.graph.nodes:
269298
if node.op == "get_attr":
270299
constant_tensor = getattr(ep.graph_module, node.target)
@@ -273,7 +302,7 @@ def lift_constant_tensor_pass(ep):
273302

274303
constant_tensor_fqn = f"_lifted_tensor_constant{len(buffers)}"
275304

276-
with ep.graph.inserting_before(first_user_input):
305+
with ep.graph.inserting_before(insert_before_node):
277306
# Insert the constant node before the first user input
278307
const_placeholder_node = ep.graph.placeholder(constant_tensor_fqn)
279308
for k, v in node.meta.items():
@@ -306,6 +335,9 @@ def lift_constant_tensor_pass(ep):
306335
new_input_specs.extend(lifted_constants)
307336
lifted_constants.clear()
308337
new_input_specs.append(s)
338+
# Add remaining lifted constants if no user inputs exist.
339+
if len(lifted_constants) > 0:
340+
new_input_specs.extend(lifted_constants)
309341
ep.graph_signature.input_specs = new_input_specs
310342
ep.graph_module.recompile()
311343
return ep

0 commit comments

Comments
 (0)