@@ -90,6 +90,18 @@ def get_random_inputs(self) -> Sequence[torch.Tensor]:
90
90
return (torch .ones (n , n , n ), 2 * torch .ones (n , n , n ), 3 * torch .ones (n , n , n ))
91
91
92
92
93
+ class ModuleLinear (torch .nn .Module ):
94
+ def __init__ (self ):
95
+ super ().__init__ ()
96
+ self .linear = torch .nn .Linear (3 , 3 )
97
+
98
+ def forward (self , x : torch .Tensor ):
99
+ return self .linear (x )
100
+
101
+ def get_random_inputs (self ):
102
+ return (torch .randn (3 ),)
103
+
104
+
93
105
#
94
106
# Backends
95
107
#
@@ -116,24 +128,23 @@ def export_module_to_program(
116
128
extract_delegate_segments : bool ,
117
129
constant_tensor_alignment : Optional [int ] = None ,
118
130
delegate_alignment : Optional [int ] = None ,
119
- method : str = "forward" ,
131
+ method_name : str = "forward" ,
120
132
) -> ExecutorchProgramManager :
121
133
eager_module = module_class ().eval ()
122
134
inputs = ()
123
135
if hasattr (eager_module , "get_random_inputs" ):
124
136
inputs = eager_module .get_random_inputs () # type: ignore[operator]
125
137
126
138
class WrapperModule (torch .nn .Module ):
127
- def __init__ (self , fn ):
139
+ def __init__ (self , fn , method_name = method_name ):
128
140
super ().__init__ ()
129
141
self .fn = fn
142
+ self .method_name = method_name
130
143
131
144
def forward (self , * args , ** kwargs ):
132
- return self .fn (* args , ** kwargs )
145
+ return getattr ( self .fn , self . method_name ) (* args , ** kwargs )
133
146
134
- exported_program = export (
135
- WrapperModule (getattr (eager_module , method )), args = inputs , strict = True
136
- )
147
+ exported_program = export (WrapperModule (eager_module ), args = inputs , strict = True )
137
148
138
149
edge_config = EdgeCompileConfig (_check_ir_validity = False )
139
150
et_config = exir .ExecutorchBackendConfig (
0 commit comments