Skip to content

Commit 73f4118

Browse files
authored
Update module wrapper so that params are explicitly registered to the wrapper
Differential Revision: D73279618 Pull Request resolved: #10305
1 parent 0b68249 commit 73f4118

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

test/models/export_delegated_program.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,18 @@ def get_random_inputs(self) -> Sequence[torch.Tensor]:
9090
return (torch.ones(n, n, n), 2 * torch.ones(n, n, n), 3 * torch.ones(n, n, n))
9191

9292

93+
class ModuleLinear(torch.nn.Module):
94+
def __init__(self):
95+
super().__init__()
96+
self.linear = torch.nn.Linear(3, 3)
97+
98+
def forward(self, x: torch.Tensor):
99+
return self.linear(x)
100+
101+
def get_random_inputs(self):
102+
return (torch.randn(3),)
103+
104+
93105
#
94106
# Backends
95107
#
@@ -116,24 +128,23 @@ def export_module_to_program(
116128
extract_delegate_segments: bool,
117129
constant_tensor_alignment: Optional[int] = None,
118130
delegate_alignment: Optional[int] = None,
119-
method: str = "forward",
131+
method_name: str = "forward",
120132
) -> ExecutorchProgramManager:
121133
eager_module = module_class().eval()
122134
inputs = ()
123135
if hasattr(eager_module, "get_random_inputs"):
124136
inputs = eager_module.get_random_inputs() # type: ignore[operator]
125137

126138
class WrapperModule(torch.nn.Module):
127-
def __init__(self, fn):
139+
def __init__(self, fn, method_name=method_name):
128140
super().__init__()
129141
self.fn = fn
142+
self.method_name = method_name
130143

131144
def forward(self, *args, **kwargs):
132-
return self.fn(*args, **kwargs)
145+
return getattr(self.fn, self.method_name)(*args, **kwargs)
133146

134-
exported_program = export(
135-
WrapperModule(getattr(eager_module, method)), args=inputs, strict=True
136-
)
147+
exported_program = export(WrapperModule(eager_module), args=inputs, strict=True)
137148

138149
edge_config = EdgeCompileConfig(_check_ir_validity=False)
139150
et_config = exir.ExecutorchBackendConfig(

test/models/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def define_common_targets():
156156
"ModuleAddMul",
157157
"ModuleAddLarge",
158158
"ModuleSubLarge",
159+
"ModuleLinear",
159160
]
160161

161162
# Name of the backend to use when exporting delegated programs.

0 commit comments

Comments
 (0)