16
16
from torch .export import export
17
17
18
18
19
- def make_test ( # noqa: C901
20
- tester : unittest .TestCase ,
21
- runtime : ModuleType ,
22
- ) -> Callable [[unittest .TestCase ], None ]:
23
- """
24
- Returns a function that operates as a test case within a unittest.TestCase class.
19
+ class ModuleAdd (torch .nn .Module ):
20
+ """The module to serialize and execute."""
25
21
26
- Used to allow the test code for pybindings to be shared across different pybinding libs
27
- which will all have different load functions. In this case each individual test case is a
28
- subfunction of wrapper.
29
- """
30
- load_fn : Callable = runtime ._load_for_executorch_from_buffer
22
+ def __init__ (self ):
23
+ super (ModuleAdd , self ).__init__ ()
31
24
32
- def wrapper (tester : unittest .TestCase ) -> None :
33
- class ModuleAdd (torch .nn .Module ):
34
- """The module to serialize and execute."""
25
+ def forward (self , x , y ):
26
+ return x + y
35
27
36
- def __init__ (self ):
37
- super ( ModuleAdd , self ). __init__ ( )
28
+ def get_methods_to_export (self ):
29
+ return ( "forward" , )
38
30
39
- def forward (self , x , y ):
40
- return x + y
31
+ def get_inputs (self ):
32
+ return ( torch . ones ( 2 , 2 ), torch . ones ( 2 , 2 ))
41
33
42
- def get_methods_to_export (self ):
43
- return ("forward" ,)
44
34
45
- def get_inputs ( self ):
46
- return ( torch . ones ( 2 , 2 ), torch . ones ( 2 , 2 ))
35
+ class ModuleMulti ( torch . nn . Module ):
36
+ """The module to serialize and execute."""
47
37
48
- class ModuleMulti ( torch . nn . Module ):
49
- """The module to serialize and execute."""
38
+ def __init__ ( self ):
39
+ super ( ModuleMulti , self ). __init__ ()
50
40
51
- def __init__ (self ):
52
- super ( ModuleMulti , self ). __init__ ()
41
+ def forward (self , x , y ):
42
+ return x + y
53
43
54
- def forward (self , x , y ):
55
- return x + y
44
+ def forward2 (self , x , y ):
45
+ return x + y + 1
56
46
57
- def forward2 (self , x , y ):
58
- return x + y + 1
47
+ def get_methods_to_export (self ):
48
+ return ( "forward" , "forward2" )
59
49
60
- def get_methods_to_export (self ):
61
- return ( "forward" , "forward2" )
50
+ def get_inputs (self ):
51
+ return ( torch . ones ( 2 , 2 ), torch . ones ( 2 , 2 ) )
62
52
63
- def get_inputs (self ):
64
- return (torch .ones (2 , 2 ), torch .ones (2 , 2 ))
65
53
66
- class ModuleAddSingleInput (torch .nn .Module ):
67
- """The module to serialize and execute."""
54
+ class ModuleAddSingleInput (torch .nn .Module ):
55
+ """The module to serialize and execute."""
68
56
69
- def __init__ (self ):
70
- super (ModuleAddSingleInput , self ).__init__ ()
57
+ def __init__ (self ):
58
+ super (ModuleAddSingleInput , self ).__init__ ()
71
59
72
- def forward (self , x ):
73
- return x + x
60
+ def forward (self , x ):
61
+ return x + x
74
62
75
- def get_methods_to_export (self ):
76
- return ("forward" ,)
63
+ def get_methods_to_export (self ):
64
+ return ("forward" ,)
77
65
78
- def get_inputs (self ):
79
- return (torch .ones (2 , 2 ),)
66
+ def get_inputs (self ):
67
+ return (torch .ones (2 , 2 ),)
80
68
81
- class ModuleAddConstReturn (torch .nn .Module ):
82
- """The module to serialize and execute."""
83
69
84
- def __init__ (self ):
85
- super (ModuleAddConstReturn , self ).__init__ ()
86
- self .state = torch .ones (2 , 2 )
70
+ class ModuleAddConstReturn (torch .nn .Module ):
71
+ """The module to serialize and execute."""
87
72
88
- def forward (self , x ):
89
- return x + self .state , self .state
73
+ def __init__ (self ):
74
+ super (ModuleAddConstReturn , self ).__init__ ()
75
+ self .state = torch .ones (2 , 2 )
90
76
91
- def get_methods_to_export (self ):
92
- return ( "forward" ,)
77
+ def forward (self , x ):
78
+ return x + self . state , self . state
93
79
94
- def get_inputs (self ):
95
- return (torch . ones ( 2 , 2 ) ,)
80
+ def get_methods_to_export (self ):
81
+ return ("forward" ,)
96
82
97
- def create_program (
98
- eager_module : torch .nn .Module ,
99
- et_config : Optional [ExecutorchBackendConfig ] = None ,
100
- ) -> Tuple [ExecutorchProgramManager , Tuple [Any , ...]]:
101
- """Returns an executorch program based on ModuleAdd, along with inputs."""
83
+ def get_inputs (self ):
84
+ return (torch .ones (2 , 2 ),)
102
85
103
- # Trace the test module and create a serialized ExecuTorch program.
104
- inputs = eager_module .get_inputs ()
105
- input_map = {}
106
- for method in eager_module .get_methods_to_export ():
107
- input_map [method ] = inputs
108
86
109
- class WrapperModule (torch .nn .Module ):
110
- def __init__ (self , fn ):
111
- super ().__init__ ()
112
- self .fn = fn
87
+ def create_program (
88
+ eager_module : torch .nn .Module ,
89
+ et_config : Optional [ExecutorchBackendConfig ] = None ,
90
+ ) -> Tuple [ExecutorchProgramManager , Tuple [Any , ...]]:
91
+ """Returns an executorch program based on ModuleAdd, along with inputs."""
113
92
114
- def forward (self , * args , ** kwargs ):
115
- return self .fn (* args , ** kwargs )
93
+ # Trace the test module and create a serialized ExecuTorch program.
94
+ inputs = eager_module .get_inputs ()
95
+ input_map = {}
96
+ for method in eager_module .get_methods_to_export ():
97
+ input_map [method ] = inputs
116
98
117
- exported_methods = {}
118
- # These cleanup passes are required to convert the `add` op to its out
119
- # variant, along with some other transformations.
120
- for method_name , method_input in input_map .items ():
121
- wrapped_mod = WrapperModule ( # pyre-ignore[16]
122
- getattr (eager_module , method_name )
123
- )
124
- exported_methods [method_name ] = export (wrapped_mod , method_input )
99
+ class WrapperModule (torch .nn .Module ):
100
+ def __init__ (self , fn ):
101
+ super ().__init__ ()
102
+ self .fn = fn
103
+
104
+ def forward (self , * args , ** kwargs ):
105
+ return self .fn (* args , ** kwargs )
106
+
107
+ exported_methods = {}
108
+ # These cleanup passes are required to convert the `add` op to its out
109
+ # variant, along with some other transformations.
110
+ for method_name , method_input in input_map .items ():
111
+ wrapped_mod = WrapperModule (getattr (eager_module , method_name ))
112
+ exported_methods [method_name ] = export (wrapped_mod , method_input )
113
+
114
+ exec_prog = to_edge (exported_methods ).to_executorch (config = et_config )
125
115
126
- exec_prog = to_edge (exported_methods ).to_executorch (config = et_config )
116
+ # Create the ExecuTorch program from the graph.
117
+ exec_prog .dump_executorch_program (verbose = True )
118
+ return (exec_prog , inputs )
127
119
128
- # Create the ExecuTorch program from the graph.
129
- exec_prog .dump_executorch_program (verbose = True )
130
- return (exec_prog , inputs )
120
+
121
+ def make_test ( # noqa: C901
122
+ tester : unittest .TestCase ,
123
+ runtime : ModuleType ,
124
+ ) -> Callable [[unittest .TestCase ], None ]:
125
+ """
126
+ Returns a function that operates as a test case within a unittest.TestCase class.
127
+
128
+ Used to allow the test code for pybindings to be shared across different pybinding libs
129
+ which will all have different load functions. In this case each individual test case is a
130
+ subfunction of wrapper.
131
+ """
132
+ load_fn : Callable = runtime ._load_for_executorch_from_buffer
133
+
134
+ def wrapper (tester : unittest .TestCase ) -> None :
131
135
132
136
######### TEST CASES #########
133
137
@@ -298,7 +302,6 @@ def test_constant_output_not_memory_planned(tester):
298
302
tester .assertEqual (str (torch .ones (2 , 2 )), str (executorch_output [1 ]))
299
303
300
304
def test_method_meta (tester ) -> None :
301
- # pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
302
305
exported_program , inputs = create_program (ModuleAdd ())
303
306
304
307
# Use pybindings to load the program and query its metadata.
@@ -345,7 +348,6 @@ def test_method_meta(tester) -> None:
345
348
346
349
def test_bad_name (tester ) -> None :
347
350
# Create an ExecuTorch program from ModuleAdd.
348
- # pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
349
351
exported_program , inputs = create_program (ModuleAdd ())
350
352
351
353
# Use pybindings to load and execute the program.
@@ -356,7 +358,6 @@ def test_bad_name(tester) -> None:
356
358
357
359
def test_verification_config (tester ) -> None :
358
360
# Create an ExecuTorch program from ModuleAdd.
359
- # pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
360
361
exported_program , inputs = create_program (ModuleAdd ())
361
362
Verification = runtime .Verification
362
363
0 commit comments