Skip to content

Commit 6a82493

Browse files
pianpwkfacebook-github-bot
authored andcommitted
Restore original placeholder names (part 1: top-level renaming) (#2859)
Summary: Pull Request resolved: #2859 X-link: pytorch/pytorch#122904 note: breaking the original diff [D55225818](https://www.internalfb.com/diff/D55225818) into 3 parts (top-level renaming, higher-order-op subgraphs, constant input de/serialization) because of its size. This PR restores original names to placeholder nodes, replacing the default names arg0_1, arg1_1, and so on. User inputs now follow the signature of mod.forward(), for example forward(x, y) produces nodes x, y. If the tensors are nested in dictionaries, lists, tuples, or dataclasses, the names are a concatenation of the path to the tensor, e.g. x = {'a': torch.randn(4), 'b': [torch.randn(4), torch.randn(4)]} produces nodes x_a, x_b_0, x_b_1. Parameters, buffers, constants, and custom objects follow the FQN of the object, prefixed by "p", "b", "c", and "obj" respectively. For example, self.bar.l0.weight gets you p_bar_l0_weight. Effect tokens are named token_1, token_2, and so on, since they are not grounded in model inputs or named attributes. Naming collisions between nodes are handled in the existing way with count suffixing. For collisions between placeholders and non-placeholder nodes, placeholders are prioritized (e.g. forward(self, mul, add) will lead to mul & add call_function nodes being suffixed). NOTE: Apologies in advance if this breaks downstream tests that rely on placeholder names, I imagine there's some tests that aren't being triggered. Currently trying to fix all errors that appear here. Examples: ```python # params, buffers, constants, inputs, torch.cond ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, p_l0_weight: "f32[4, 4]", p_l0_bias: "f32[4]", c_alpha: "f32[4]", b_beta: "f32[4]", x_0_a: "f32[4, 4]", y: "f32[4, 4]"): # No stacktrace found for following nodes mul: "f32[4, 4]" = torch.ops.aten.mul.Tensor(x_0_a, x_0_a) t: "f32[4, 4]" = torch.ops.aten.t.default(p_l0_weight); p_l0_weight = None addmm: "f32[4, 4]" = torch.ops.aten.addmm.default(p_l0_bias, y, t); p_l0_bias = y = t = None return addmm # model code class Bar(torch.nn.Module): def forward(self, x): return x * x class Foo(torch.nn.Module): def __init__(self): super().__init__() self.bar = Bar() self.l0 = torch.nn.Linear(4, 4) self.alpha = torch.randn(4) self.register_buffer('beta', torch.randn(4)) def forward(self, x, y): x = x[0]['a'] mul = self.bar(x) z1 = self.l0(y) return z1 # custom objects, dataclasses, tokens, constant inputs ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, token_1: "f32[0]", obj_attr, data_x: "f32[4, 4]", data_y: "f32[4, 4]", mode): # No stacktrace found for following nodes mul: "f32[4, 4]" = torch.ops.aten.mul.Scalar(data_x, 30); data_x = None div: "f32[4, 4]" = torch.ops.aten.div.Tensor_mode(data_y, 1.0, rounding_mode = 'floor'); data_y = None add: "f32[4, 4]" = torch.ops.aten.add.Tensor(mul, div); mul = div = None with_effects = torch._higher_order_ops.effects.with_effects(token_1, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add); token_1 = obj_attr = add = None getitem: "f32[0]" = with_effects[0] getitem_1: "f32[4, 4]" = with_effects[1]; with_effects = None return (getitem, getitem_1) # model code class Foo(torch.nn.Module): def __init__(self): super().__init__() self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) def forward(self, data, a=1.0, mode="floor"): x = self.attr.add_tensor(data.x) + torch.div(data.y, a, rounding_mode=mode) x = torch.ops._TorchScriptTesting.takes_foo(self.attr, x) return x dataclass class DataClass: x: Tensor y: Tensor register_dataclass_as_pytree_node( DataClass, serialized_type_name="test.DataClass" ) args = (DataClass(x=torch.randn(4, 4), y=torch.randn(4, 4)), ) kwargs = {'mode': 'floor'} ep = torch.export.export(Foo(), args, kwargs, strict=False) ``` Reviewed By: angelayi Differential Revision: D55456418
1 parent f64130e commit 6a82493

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

exir/backend/test/test_partitioner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,6 @@ def partition(
520520
_ = edge.to_backend(PartitionerTagData())
521521

522522
self.assertEqual(
523-
"constant data node (arg0_1) is tagged with (tag0) but has user (aten_sub_tensor) which has tag (None)",
523+
"constant data node (b_const) is tagged with (tag0) but has user (aten_sub_tensor) which has tag (None)",
524524
str(error.exception),
525525
)

0 commit comments

Comments
 (0)