16
16
from torch .export import export
17
17
18
18
19
- def make_test ( # noqa: C901
20
- tester : unittest .TestCase ,
21
- runtime : ModuleType ,
22
- ) -> Callable [[unittest .TestCase ], None ]:
23
- """
24
- Returns a function that operates as a test case within a unittest.TestCase class.
19
+ class ModuleAdd (torch .nn .Module ):
20
+ """The module to serialize and execute."""
25
21
26
- Used to allow the test code for pybindings to be shared across different pybinding libs
27
- which will all have different load functions. In this case each individual test case is a
28
- subfunction of wrapper.
29
- """
30
- load_fn : Callable = runtime ._load_for_executorch_from_buffer
22
+ def __init__ (self ):
23
+ super (ModuleAdd , self ).__init__ ()
31
24
32
- def wrapper (tester : unittest .TestCase ) -> None :
33
- class ModuleAdd (torch .nn .Module ):
34
- """The module to serialize and execute."""
25
+ def forward (self , x , y ):
26
+ return x + y
35
27
36
- def __init__ (self ):
37
- super ( ModuleAdd , self ). __init__ ( )
28
+ def get_methods_to_export (self ):
29
+ return ( "forward" , )
38
30
39
- def forward (self , x , y ):
40
- return x + y
31
+ def get_inputs (self ):
32
+ return ( torch . ones ( 2 , 2 ), torch . ones ( 2 , 2 ))
41
33
42
- def get_methods_to_export (self ):
43
- return ("forward" ,)
44
34
45
- def get_inputs ( self ):
46
- return ( torch . ones ( 2 , 2 ), torch . ones ( 2 , 2 ))
35
+ class ModuleMulti ( torch . nn . Module ):
36
+ """The module to serialize and execute."""
47
37
48
- class ModuleMulti ( torch . nn . Module ):
49
- """The module to serialize and execute."""
38
+ def __init__ ( self ):
39
+ super ( ModuleMulti , self ). __init__ ()
50
40
51
- def __init__ (self ):
52
- super ( ModuleMulti , self ). __init__ ()
41
+ def forward (self , x , y ):
42
+ return x + y
53
43
54
- def forward (self , x , y ):
55
- return x + y
44
+ def forward2 (self , x , y ):
45
+ return x + y + 1
56
46
57
- def forward2 (self , x , y ):
58
- return x + y + 1
47
+ def get_methods_to_export (self ):
48
+ return ( "forward" , "forward2" )
59
49
60
- def get_methods_to_export (self ):
61
- return ( "forward" , "forward2" )
50
+ def get_inputs (self ):
51
+ return ( torch . ones ( 2 , 2 ), torch . ones ( 2 , 2 ) )
62
52
63
- def get_inputs (self ):
64
- return (torch .ones (2 , 2 ), torch .ones (2 , 2 ))
65
53
66
- class ModuleAddSingleInput (torch .nn .Module ):
67
- """The module to serialize and execute."""
54
+ class ModuleAddSingleInput (torch .nn .Module ):
55
+ """The module to serialize and execute."""
68
56
69
- def __init__ (self ):
70
- super (ModuleAddSingleInput , self ).__init__ ()
57
+ def __init__ (self ):
58
+ super (ModuleAddSingleInput , self ).__init__ ()
71
59
72
- def forward (self , x ):
73
- return x + x
60
+ def forward (self , x ):
61
+ return x + x
74
62
75
- def get_methods_to_export (self ):
76
- return ("forward" ,)
63
+ def get_methods_to_export (self ):
64
+ return ("forward" ,)
77
65
78
- def get_inputs (self ):
79
- return (torch .ones (2 , 2 ),)
66
+ def get_inputs (self ):
67
+ return (torch .ones (2 , 2 ),)
80
68
81
- class ModuleAddConstReturn (torch .nn .Module ):
82
- """The module to serialize and execute."""
83
69
84
- def __init__ (self ):
85
- super (ModuleAddConstReturn , self ).__init__ ()
86
- self .state = torch .ones (2 , 2 )
70
+ class ModuleAddConstReturn (torch .nn .Module ):
71
+ """The module to serialize and execute."""
87
72
88
- def forward (self , x ):
89
- return x + self .state , self .state
73
+ def __init__ (self ):
74
+ super (ModuleAddConstReturn , self ).__init__ ()
75
+ self .state = torch .ones (2 , 2 )
90
76
91
- def get_methods_to_export (self ):
92
- return ( "forward" ,)
77
+ def forward (self , x ):
78
+ return x + self . state , self . state
93
79
94
- def get_inputs (self ):
95
- return (torch . ones ( 2 , 2 ) ,)
80
+ def get_methods_to_export (self ):
81
+ return ("forward" ,)
96
82
97
- def create_program (
98
- eager_module : torch .nn .Module ,
99
- et_config : Optional [ExecutorchBackendConfig ] = None ,
100
- ) -> Tuple [ExecutorchProgramManager , Tuple [Any , ...]]:
101
- """Returns an executorch program based on ModuleAdd, along with inputs."""
83
+ def get_inputs (self ):
84
+ return (torch .ones (2 , 2 ),)
102
85
103
- # Trace the test module and create a serialized ExecuTorch program.
104
- inputs = eager_module .get_inputs ()
105
- input_map = {}
106
- for method in eager_module .get_methods_to_export ():
107
- input_map [method ] = inputs
108
86
109
- class WrapperModule (torch .nn .Module ):
110
- def __init__ (self , fn ):
111
- super ().__init__ ()
112
- self .fn = fn
87
+ class ModuleAddConstReturn (torch .nn .Module ):
88
+ """The module to serialize and execute."""
113
89
114
- def forward (self , * args , ** kwargs ):
115
- return self .fn (* args , ** kwargs )
90
+ def __init__ (self ):
91
+ super (ModuleAddConstReturn , self ).__init__ ()
92
+ self .state = torch .ones (2 , 2 )
116
93
117
- exported_methods = {}
118
- # These cleanup passes are required to convert the `add` op to its out
119
- # variant, along with some other transformations.
120
- for method_name , method_input in input_map .items ():
121
- wrapped_mod = WrapperModule ( # pyre-ignore[16]
122
- getattr (eager_module , method_name )
123
- )
124
- exported_methods [method_name ] = export (wrapped_mod , method_input )
94
+ def forward (self , x ):
95
+ return x + self .state , self .state
96
+
97
+ def get_methods_to_export (self ):
98
+ return ("forward" ,)
99
+
100
+ def get_inputs (self ):
101
+ return (torch .ones (2 , 2 ),)
102
+
103
+
104
+ def create_program (
105
+ eager_module : torch .nn .Module ,
106
+ et_config : Optional [ExecutorchBackendConfig ] = None ,
107
+ ) -> Tuple [ExecutorchProgramManager , Tuple [Any , ...]]:
108
+ """Returns an executorch program based on ModuleAdd, along with inputs."""
109
+
110
+ # Trace the test module and create a serialized ExecuTorch program.
111
+ inputs = eager_module .get_inputs ()
112
+ input_map = {}
113
+ for method in eager_module .get_methods_to_export ():
114
+ input_map [method ] = inputs
115
+
116
+ class WrapperModule (torch .nn .Module ):
117
+ def __init__ (self , fn ):
118
+ super ().__init__ ()
119
+ self .fn = fn
125
120
126
- exec_prog = to_edge (exported_methods ).to_executorch (config = et_config )
121
+ def forward (self , * args , ** kwargs ):
122
+ return self .fn (* args , ** kwargs )
127
123
128
- # Create the ExecuTorch program from the graph.
129
- exec_prog .dump_executorch_program (verbose = True )
130
- return (exec_prog , inputs )
124
+ exported_methods = {}
125
+ # These cleanup passes are required to convert the `add` op to its out
126
+ # variant, along with some other transformations.
127
+ for method_name , method_input in input_map .items ():
128
+ wrapped_mod = WrapperModule (getattr (eager_module , method_name ))
129
+ exported_methods [method_name ] = export (wrapped_mod , method_input )
130
+
131
+ exec_prog = to_edge (exported_methods ).to_executorch (config = et_config )
132
+
133
+ # Create the ExecuTorch program from the graph.
134
+ exec_prog .dump_executorch_program (verbose = True )
135
+ return (exec_prog , inputs )
136
+
137
+
138
+ def make_test ( # noqa: C901
139
+ tester : unittest .TestCase ,
140
+ runtime : ModuleType ,
141
+ ) -> Callable [[unittest .TestCase ], None ]:
142
+ """
143
+ Returns a function that operates as a test case within a unittest.TestCase class.
144
+
145
+ Used to allow the test code for pybindings to be shared across different pybinding libs
146
+ which will all have different load functions. In this case each individual test case is a
147
+ subfunction of wrapper.
148
+ """
149
+ load_fn : Callable = runtime ._load_for_executorch_from_buffer
150
+
151
+ def wrapper (tester : unittest .TestCase ) -> None :
131
152
132
153
######### TEST CASES #########
133
154
@@ -300,7 +321,6 @@ def test_constant_output_not_memory_planned(tester):
300
321
######### RUN TEST CASES #########
301
322
302
323
def test_method_meta (tester ) -> None :
303
- # pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
304
324
exported_program , inputs = create_program (ModuleAdd ())
305
325
306
326
# Use pybindings to load the program and query its metadata.
@@ -347,7 +367,6 @@ def test_method_meta(tester) -> None:
347
367
348
368
def test_bad_name (tester ) -> None :
349
369
# Create an ExecuTorch program from ModuleAdd.
350
- # pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
351
370
exported_program , inputs = create_program (ModuleAdd ())
352
371
353
372
# Use pybindings to load and execute the program.
@@ -358,7 +377,6 @@ def test_bad_name(tester) -> None:
358
377
359
378
def test_verification_config (tester ) -> None :
360
379
# Create an ExecuTorch program from ModuleAdd.
361
- # pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
362
380
exported_program , inputs = create_program (ModuleAdd ())
363
381
Verification = runtime .Verification
364
382
0 commit comments