Skip to content

Commit a563d8a

Browse files
WIP need to clean up functions.
1 parent 06d27e7 commit a563d8a

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

exir/memory_planning.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from executorch.exir.tensor import TensorSpec
2525

2626
from torch import fx
27-
from torch.export.exported_program import ExportGraphSignature
27+
from torch.export.exported_program import ExportGraphSignature, InputKind
2828
from torch.fx import Node
2929
from torch.utils._pytree import tree_flatten
3030

@@ -247,7 +247,10 @@ def verify_graph_input_output(self) -> None:
247247
graph_output_allocated = allocated
248248
has_dynamic_unbound_output |= has_dynamic_unbound_tensor
249249

250-
if "placeholder" in check_list:
250+
# only check if inputs are allocated if there are user inputs:
251+
user_inputs_exist = len(list(filter(lambda input: input.kind == InputKind.USER_INPUT, self.graph_signature.input_specs))) > 0
252+
253+
if "placeholder" in check_list and user_inputs_exist:
251254
assert graph_input_allocated is not None, "graph_input_allocated not set"
252255
if not has_dynamic_unbound_input:
253256
assert (

exir/program/_program.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,13 +268,42 @@ def lift_constant_tensor_pass(ep):
268268
buffers = list(graph_signature.buffers)
269269

270270
fake_mode = list(ep.graph.nodes)[0].meta["val"].fake_mode
271-
first_user_input = None
271+
insert_before_node = None
272272
lifted_constants = []
273273
for node in ep.graph.nodes:
274274
if node.op == "placeholder" and node.name in graph_signature.user_inputs:
275-
first_user_input = node
275+
insert_before_node = node # first user input
276276
break
277277

278+
if insert_before_node is None:
279+
# we have no user inputs, find the node after the last buffer
280+
# (that we will insert the lifted constants before).
281+
# this is a bit hacky, but I am not certain of what the contract is for
282+
# node ordering. is the first non-placeholder node guranteed to be the
283+
# first node after input paramters? what if there is no op, and it is
284+
# just placeholders? Easier to just find the last buffer, and insert after.
285+
286+
# also error if we have no buffers and no user inputs... if that is an issue, fix it later?
287+
last_buffer = None
288+
for node in ep.graph.nodes:
289+
node_buffer_fqn = graph_signature.inputs_to_buffers.get(node.name, None)
290+
# not sure if both cases are needed, if is it possible to encounter a
291+
# buffer that is not a user input?
292+
if (
293+
node_buffer_fqn is not None
294+
and node_buffer_fqn in graph_signature.buffers
295+
):
296+
last_buffer = node
297+
continue
298+
if node.op == "placeholder" and node.name in graph_signature.buffers:
299+
last_buffer = node
300+
continue
301+
# we have our last buffer, grab the node after it, to insert the lifted constants before.
302+
insert_before_node = last_buffer.next
303+
304+
if insert_before_node is None:
305+
raise ValueError("No user inputs and no buffers found. Cannot lift constants.")
306+
278307
for node in ep.graph.nodes:
279308
if node.op == "get_attr":
280309
constant_tensor = getattr(ep.graph_module, node.target)
@@ -283,7 +312,7 @@ def lift_constant_tensor_pass(ep):
283312

284313
constant_tensor_fqn = f"_lifted_tensor_constant{len(buffers)}"
285314

286-
with ep.graph.inserting_before(first_user_input):
315+
with ep.graph.inserting_before(insert_before_node):
287316
# Insert the constant node before the first user input
288317
const_placeholder_node = ep.graph.placeholder(constant_tensor_fqn)
289318
for k, v in node.meta.items():
@@ -316,6 +345,9 @@ def lift_constant_tensor_pass(ep):
316345
new_input_specs.extend(lifted_constants)
317346
lifted_constants.clear()
318347
new_input_specs.append(s)
348+
# Add remaining lifted constants if no user inputs exist.
349+
if len(lifted_constants) > 0:
350+
new_input_specs.extend(lifted_constants)
319351
ep.graph_signature.input_specs = new_input_specs
320352
ep.graph_module.recompile()
321353
return ep

0 commit comments

Comments
 (0)