15
15
from torch .export import export
16
16
17
17
18
- def make_test ( # noqa: C901
19
- tester : unittest .TestCase ,
20
- runtime : ModuleType ,
21
- ) -> Callable [[unittest .TestCase ], None ]:
22
- """
23
- Returns a function that operates as a test case within a unittest.TestCase class.
18
+ class ModuleAdd (torch .nn .Module ):
19
+ """The module to serialize and execute."""
24
20
25
- Used to allow the test code for pybindings to be shared across different pybinding libs
26
- which will all have different load functions. In this case each individual test case is a
27
- subfunction of wrapper.
28
- """
29
- load_fn : Callable = runtime ._load_for_executorch_from_buffer
21
+ def __init__ (self ):
22
+ super (ModuleAdd , self ).__init__ ()
30
23
31
- def wrapper (tester : unittest .TestCase ) -> None :
32
- class ModuleAdd (torch .nn .Module ):
33
- """The module to serialize and execute."""
24
+ def forward (self , x , y ):
25
+ return x + y
34
26
35
- def __init__ (self ):
36
- super ( ModuleAdd , self ). __init__ ( )
27
+ def get_methods_to_export (self ):
28
+ return ( "forward" , )
37
29
38
- def forward (self , x , y ):
39
- return x + y
30
+ def get_inputs (self ):
31
+ return ( torch . ones ( 2 , 2 ), torch . ones ( 2 , 2 ))
40
32
41
- def get_methods_to_export (self ):
42
- return ("forward" ,)
43
33
44
- def get_inputs ( self ):
45
- return ( torch . ones ( 2 , 2 ), torch . ones ( 2 , 2 ))
34
+ class ModuleMulti ( torch . nn . Module ):
35
+ """The module to serialize and execute."""
46
36
47
- class ModuleMulti ( torch . nn . Module ):
48
- """The module to serialize and execute."""
37
+ def __init__ ( self ):
38
+ super ( ModuleMulti , self ). __init__ ()
49
39
50
- def __init__ (self ):
51
- super ( ModuleMulti , self ). __init__ ()
40
+ def forward (self , x , y ):
41
+ return x + y
52
42
53
- def forward (self , x , y ):
54
- return x + y
43
+ def forward2 (self , x , y ):
44
+ return x + y + 1
55
45
56
- def forward2 (self , x , y ):
57
- return x + y + 1
46
+ def get_methods_to_export (self ):
47
+ return ( "forward" , "forward2" )
58
48
59
- def get_methods_to_export (self ):
60
- return ( "forward" , "forward2" )
49
+ def get_inputs (self ):
50
+ return ( torch . ones ( 2 , 2 ), torch . ones ( 2 , 2 ) )
61
51
62
- def get_inputs (self ):
63
- return (torch .ones (2 , 2 ), torch .ones (2 , 2 ))
64
52
65
- class ModuleAddSingleInput (torch .nn .Module ):
66
- """The module to serialize and execute."""
53
+ class ModuleAddSingleInput (torch .nn .Module ):
54
+ """The module to serialize and execute."""
67
55
68
- def __init__ (self ):
69
- super (ModuleAddSingleInput , self ).__init__ ()
56
+ def __init__ (self ):
57
+ super (ModuleAddSingleInput , self ).__init__ ()
70
58
71
- def forward (self , x ):
72
- return x + x
59
+ def forward (self , x ):
60
+ return x + x
73
61
74
- def get_methods_to_export (self ):
75
- return ("forward" ,)
62
+ def get_methods_to_export (self ):
63
+ return ("forward" ,)
76
64
77
- def get_inputs (self ):
78
- return (torch .ones (2 , 2 ),)
65
+ def get_inputs (self ):
66
+ return (torch .ones (2 , 2 ),)
79
67
80
- def create_program (
81
- eager_module : torch .nn .Module ,
82
- ) -> Tuple [ExecutorchProgramManager , Tuple [Any , ...]]:
83
- """Returns an executorch program based on ModuleAdd, along with inputs."""
84
68
85
- # Trace the test module and create a serialized ExecuTorch program.
86
- inputs = eager_module .get_inputs ()
87
- input_map = {}
88
- for method in eager_module .get_methods_to_export ():
89
- input_map [method ] = inputs
69
+ class ModuleAddConstReturn (torch .nn .Module ):
70
+ """The module to serialize and execute."""
90
71
91
- class WrapperModule (torch .nn .Module ):
92
- def __init__ (self , fn ):
93
- super ().__init__ ()
94
- self .fn = fn
72
+ def __init__ (self ):
73
+ super (ModuleAddConstReturn , self ).__init__ ()
74
+ self .state = torch .ones (2 , 2 )
95
75
96
- def forward (self , * args , ** kwargs ):
97
- return self .fn ( * args , ** kwargs )
76
+ def forward (self , x ):
77
+ return x + self .state , self . state
98
78
99
- exported_methods = {}
100
- # These cleanup passes are required to convert the `add` op to its out
101
- # variant, along with some other transformations.
102
- for method_name , method_input in input_map .items ():
103
- wrapped_mod = WrapperModule ( # pyre-ignore[16]
104
- getattr (eager_module , method_name )
105
- )
106
- exported_methods [method_name ] = export (wrapped_mod , method_input )
79
+ def get_methods_to_export (self ):
80
+ return ("forward" ,)
81
+
82
+ def get_inputs (self ):
83
+ return (torch .ones (2 , 2 ),)
84
+
85
+
86
+ def create_program (
87
+ eager_module : torch .nn .Module ,
88
+ et_config : Optional [ExecutorchBackendConfig ] = None ,
89
+ ) -> Tuple [ExecutorchProgramManager , Tuple [Any , ...]]:
90
+ """Returns an executorch program based on ModuleAdd, along with inputs."""
91
+
92
+ # Trace the test module and create a serialized ExecuTorch program.
93
+ inputs = eager_module .get_inputs ()
94
+ input_map = {}
95
+ for method in eager_module .get_methods_to_export ():
96
+ input_map [method ] = inputs
97
+
98
+ class WrapperModule (torch .nn .Module ):
99
+ def __init__ (self , fn ):
100
+ super ().__init__ ()
101
+ self .fn = fn
107
102
108
- exec_prog = to_edge (exported_methods ).to_executorch ()
103
+ def forward (self , * args , ** kwargs ):
104
+ return self .fn (* args , ** kwargs )
109
105
110
- # Create the ExecuTorch program from the graph.
111
- exec_prog .dump_executorch_program (verbose = True )
112
- return (exec_prog , inputs )
106
+ exported_methods = {}
107
+ # These cleanup passes are required to convert the `add` op to its out
108
+ # variant, along with some other transformations.
109
+ for method_name , method_input in input_map .items ():
110
+ wrapped_mod = WrapperModule (getattr (eager_module , method_name ))
111
+ exported_methods [method_name ] = export (wrapped_mod , method_input )
112
+
113
+ exec_prog = to_edge (exported_methods ).to_executorch (config = et_config )
114
+
115
+ # Create the ExecuTorch program from the graph.
116
+ exec_prog .dump_executorch_program (verbose = True )
117
+ return (exec_prog , inputs )
118
+
119
+
120
+ def make_test ( # noqa: C901
121
+ tester : unittest .TestCase ,
122
+ runtime : ModuleType ,
123
+ ) -> Callable [[unittest .TestCase ], None ]:
124
+ """
125
+ Returns a function that operates as a test case within a unittest.TestCase class.
126
+
127
+ Used to allow the test code for pybindings to be shared across different pybinding libs
128
+ which will all have different load functions. In this case each individual test case is a
129
+ subfunction of wrapper.
130
+ """
131
+ load_fn : Callable = runtime ._load_for_executorch_from_buffer
132
+
133
+ def wrapper (tester : unittest .TestCase ) -> None :
113
134
114
135
######### TEST CASES #########
115
136
@@ -280,7 +301,6 @@ def test_constant_output_not_memory_planned(tester):
280
301
tester .assertEqual (str (torch .ones (2 , 2 )), str (executorch_output [1 ]))
281
302
282
303
def test_method_meta (tester ) -> None :
283
- # pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
284
304
exported_program , inputs = create_program (ModuleAdd ())
285
305
286
306
# Use pybindings to load the program and query its metadata.
@@ -327,7 +347,6 @@ def test_method_meta(tester) -> None:
327
347
328
348
def test_bad_name (tester ) -> None :
329
349
# Create an ExecuTorch program from ModuleAdd.
330
- # pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
331
350
exported_program , inputs = create_program (ModuleAdd ())
332
351
333
352
# Use pybindings to load and execute the program.
@@ -338,7 +357,6 @@ def test_bad_name(tester) -> None:
338
357
339
358
def test_verification_config (tester ) -> None :
340
359
# Create an ExecuTorch program from ModuleAdd.
341
- # pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
342
360
exported_program , inputs = create_program (ModuleAdd ())
343
361
Verification = runtime .Verification
344
362
0 commit comments