Skip to content

Commit 687cbcc

Browse files
committed
Update module wrapper so that params are explicitly registered to the wrapper
Seeing issue with linear where the fqns for constants disappear. Registering self.method_name as a submodule of wrapper means that the parameters are registered to the wrapper. cc angelayi ``` File "/data/users/lfq/fbsource/buck-out/v2/gen/fbcode/1af94fa701700343/executorch/test/models/__export_delegated_program__/export_delegated_program#link-tree/torch/export/_trace.py", line 1980, in _export_for_training export_artifact = export_func( File "/data/users/lfq/fbsource/buck-out/v2/gen/fbcode/1af94fa701700343/executorch/test/models/__export_delegated_program__/export_delegated_program#link-tree/torch/export/_trace.py", line 1473, in _strict_export _replace_param_buffer_names(param_buffer_table, export_graph_signature) File "/data/users/lfq/fbsource/buck-out/v2/gen/fbcode/1af94fa701700343/executorch/test/models/__export_delegated_program__/export_delegated_program#link-tree/torch/export/_trace.py", line 272, in _replace_param_buffer_names spec.target = param_buffer_table[spec.target] KeyError: 'L__self___fn___self___linear.weight' ``` Differential Revision: [D73279618](https://our.internmc.facebook.com/intern/diff/D73279618/) ghstack-source-id: 279019422 Pull Request resolved: #10305
1 parent 9ce83e3 commit 687cbcc

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

test/models/export_delegated_program.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,18 @@ def get_random_inputs(self) -> Sequence[torch.Tensor]:
8787
n = 10 # to create a large tensor
8888
return (torch.ones(n, n, n), 2 * torch.ones(n, n, n), 3 * torch.ones(n, n, n))
8989

90+
91+
class ModuleLinear(torch.nn.Module):
92+
def __init__(self):
93+
super().__init__()
94+
self.linear = torch.nn.Linear(3, 3)
95+
96+
def forward(self, x: torch.Tensor):
97+
return self.linear(x)
98+
99+
def get_random_inputs(self):
100+
return (torch.randn(3),)
101+
90102

91103
#
92104
# Backends
@@ -114,24 +126,23 @@ def export_module_to_program(
114126
extract_delegate_segments: bool,
115127
constant_tensor_alignment: Optional[int] = None,
116128
delegate_alignment: Optional[int] = None,
117-
method: str = "forward",
129+
method_name: str = "forward",
118130
) -> ExecutorchProgramManager:
119131
eager_module = module_class().eval()
120132
inputs = ()
121133
if hasattr(eager_module, "get_random_inputs"):
122134
inputs = eager_module.get_random_inputs() # type: ignore[operator]
123135

124136
class WrapperModule(torch.nn.Module):
    """Wraps an eager module so that one of its named methods can be exported.

    Assigning ``fn`` to an attribute of a ``torch.nn.Module`` registers it
    as a submodule, so its parameters and buffers get fully-qualified names
    rooted at this wrapper — which ``torch.export`` relies on when building
    the parameter/buffer name table.
    """

    def __init__(self, fn, method_name=method_name):
        super().__init__()
        # nn.Module attribute assignment registers `fn` as a submodule,
        # placing its parameters under this wrapper's FQN space.
        self.fn = fn
        self.method_name = method_name

    def forward(self, *args, **kwargs):
        # Resolve the requested method (e.g. "forward") on the wrapped
        # module at call/trace time and dispatch to it.
        target = getattr(self.fn, self.method_name)
        return target(*args, **kwargs)
131144

132-
exported_program = export(
133-
WrapperModule(getattr(eager_module, method)), args=inputs, strict=True
134-
)
145+
exported_program = export(WrapperModule(eager_module), args=inputs, strict=True)
135146

136147
edge_config = EdgeCompileConfig(_check_ir_validity=False)
137148
et_config = exir.ExecutorchBackendConfig(

test/models/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def define_common_targets():
155155
"ModuleAddMul",
156156
"ModuleAddLarge",
157157
"ModuleSubLarge",
158+
"ModuleLinear",
158159
]
159160

160161
# Name of the backend to use when exporting delegated programs.

0 commit comments

Comments
 (0)