Skip to content

Commit 3272ad7

Browse files
angelayifacebook-github-bot
authored andcommitted
Change _generate_new_graph_signature (#206)
Summary: Pull Request resolved: #206 X-link: pytorch/pytorch#108571 Previously `_generate_new_graph_signature` had the assumption that all transformations were not in place. However, this is an incorrect assumption leading to mysterious failures when running passes doing in-place modifications. This function is technically only needed in the case where the user output node or user input node name is changed. For example, if the user output node was "add" but a pass changes all the "add"s to "mul"s, then the output node will now be named "mul", which we have to update. For cases where users change the number of user inputs/outputs, number of parameters/buffers, or the names of parameters/buffers it will require extra work on the user's side to update the graph signature, since there is no automatic way for us to detect where to put what. Note: this doesn't actually change the names for the buffers_to_mutate part of the graph signature, but we're going to assume this is rare, because implementing auto-fixing for that is a little hard... Reviewed By: digantdesai Differential Revision: D48917505 fbshipit-source-id: 53132c750ef4b0610fa0bb0fc4c944f4e6a2afc6
1 parent dd57cc2 commit 3272ad7

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

exir/capture/_capture.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,17 @@ def convert_to_fake(x):
204204
_instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args)
205205
graph_module._apply(torch.Tensor.contiguous)
206206

207+
user_inputs = [
208+
node.name for node in graph_module.graph.nodes if node.op == "placeholder"
209+
]
210+
output_node = list(graph_module.graph.nodes)[-1]
211+
assert output_node.op == "output"
212+
user_outputs = [arg.name for arg in output_node.args[0]]
213+
207214
ep = ExportedProgram(
208215
graph_module,
209216
graph_module.graph,
210-
ExportGraphSignature([], [], [], [], {}, {}, {}, None),
217+
ExportGraphSignature([], [], user_inputs, user_outputs, {}, {}, {}, None),
211218
CallSpec(in_spec, out_spec),
212219
{},
213220
{},

exir/lowered_backend_module.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,8 @@ def program(self, emit_stacktrace: bool = False) -> Program:
220220
lowered_exported_program.graph_signature.user_inputs = [
221221
user_input
222222
for user_input in lowered_exported_program.graph_signature.user_inputs
223-
if user_input in inputs_to_parameters or user_input in inputs_to_buffers
223+
if user_input not in inputs_to_parameters
224+
and user_input not in inputs_to_buffers
224225
]
225226
lowered_exported_program.graph_signature.buffers = {}
226227
lowered_exported_program.graph_signature.parameters = {}

0 commit comments

Comments
 (0)