|
16 | 16 | from torch.export import export
|
17 | 17 |
|
18 | 18 |
|
19 |
| -def make_test( # noqa: C901 |
20 |
| - tester: unittest.TestCase, |
21 |
| - runtime: ModuleType, |
22 |
| -) -> Callable[[unittest.TestCase], None]: |
23 |
| - """ |
24 |
| - Returns a function that operates as a test case within a unittest.TestCase class. |
| 19 | +class ModuleAdd(torch.nn.Module): |
| 20 | + """The module to serialize and execute.""" |
25 | 21 |
|
26 |
| - Used to allow the test code for pybindings to be shared across different pybinding libs |
27 |
| - which will all have different load functions. In this case each individual test case is a |
28 |
| - subfunction of wrapper. |
29 |
| - """ |
30 |
| - load_fn: Callable = runtime._load_for_executorch_from_buffer |
| 22 | + def __init__(self): |
| 23 | + super(ModuleAdd, self).__init__() |
31 | 24 |
|
32 |
| - def wrapper(tester: unittest.TestCase) -> None: |
33 |
| - class ModuleAdd(torch.nn.Module): |
34 |
| - """The module to serialize and execute.""" |
| 25 | + def forward(self, x, y): |
| 26 | + return x + y |
35 | 27 |
|
36 |
| - def __init__(self): |
37 |
| - super(ModuleAdd, self).__init__() |
| 28 | + def get_methods_to_export(self): |
| 29 | + return ("forward",) |
38 | 30 |
|
39 |
| - def forward(self, x, y): |
40 |
| - return x + y |
| 31 | + def get_inputs(self): |
| 32 | + return (torch.ones(2, 2), torch.ones(2, 2)) |
41 | 33 |
|
42 |
| - def get_methods_to_export(self): |
43 |
| - return ("forward",) |
44 | 34 |
|
45 |
| - def get_inputs(self): |
46 |
| - return (torch.ones(2, 2), torch.ones(2, 2)) |
| 35 | +class ModuleMulti(torch.nn.Module): |
| 36 | + """The module to serialize and execute.""" |
47 | 37 |
|
48 |
| - class ModuleMulti(torch.nn.Module): |
49 |
| - """The module to serialize and execute.""" |
| 38 | + def __init__(self): |
| 39 | + super(ModuleMulti, self).__init__() |
50 | 40 |
|
51 |
| - def __init__(self): |
52 |
| - super(ModuleMulti, self).__init__() |
| 41 | + def forward(self, x, y): |
| 42 | + return x + y |
53 | 43 |
|
54 |
| - def forward(self, x, y): |
55 |
| - return x + y |
| 44 | + def forward2(self, x, y): |
| 45 | + return x + y + 1 |
56 | 46 |
|
57 |
| - def forward2(self, x, y): |
58 |
| - return x + y + 1 |
| 47 | + def get_methods_to_export(self): |
| 48 | + return ("forward", "forward2") |
59 | 49 |
|
60 |
| - def get_methods_to_export(self): |
61 |
| - return ("forward", "forward2") |
| 50 | + def get_inputs(self): |
| 51 | + return (torch.ones(2, 2), torch.ones(2, 2)) |
62 | 52 |
|
63 |
| - def get_inputs(self): |
64 |
| - return (torch.ones(2, 2), torch.ones(2, 2)) |
65 | 53 |
|
66 |
| - class ModuleAddSingleInput(torch.nn.Module): |
67 |
| - """The module to serialize and execute.""" |
| 54 | +class ModuleAddSingleInput(torch.nn.Module): |
| 55 | + """The module to serialize and execute.""" |
68 | 56 |
|
69 |
| - def __init__(self): |
70 |
| - super(ModuleAddSingleInput, self).__init__() |
| 57 | + def __init__(self): |
| 58 | + super(ModuleAddSingleInput, self).__init__() |
71 | 59 |
|
72 |
| - def forward(self, x): |
73 |
| - return x + x |
| 60 | + def forward(self, x): |
| 61 | + return x + x |
74 | 62 |
|
75 |
| - def get_methods_to_export(self): |
76 |
| - return ("forward",) |
| 63 | + def get_methods_to_export(self): |
| 64 | + return ("forward",) |
77 | 65 |
|
78 |
| - def get_inputs(self): |
79 |
| - return (torch.ones(2, 2),) |
| 66 | + def get_inputs(self): |
| 67 | + return (torch.ones(2, 2),) |
80 | 68 |
|
81 |
| - class ModuleAddConstReturn(torch.nn.Module): |
82 |
| - """The module to serialize and execute.""" |
83 | 69 |
|
84 |
| - def __init__(self): |
85 |
| - super(ModuleAddConstReturn, self).__init__() |
86 |
| - self.state = torch.ones(2, 2) |
| 70 | +class ModuleAddConstReturn(torch.nn.Module): |
| 71 | + """The module to serialize and execute.""" |
87 | 72 |
|
88 |
| - def forward(self, x): |
89 |
| - return x + self.state, self.state |
| 73 | + def __init__(self): |
| 74 | + super(ModuleAddConstReturn, self).__init__() |
| 75 | + self.state = torch.ones(2, 2) |
90 | 76 |
|
91 |
| - def get_methods_to_export(self): |
92 |
| - return ("forward",) |
| 77 | + def forward(self, x): |
| 78 | + return x + self.state, self.state |
93 | 79 |
|
94 |
| - def get_inputs(self): |
95 |
| - return (torch.ones(2, 2),) |
| 80 | + def get_methods_to_export(self): |
| 81 | + return ("forward",) |
96 | 82 |
|
97 |
| - def create_program( |
98 |
| - eager_module: torch.nn.Module, |
99 |
| - et_config: Optional[ExecutorchBackendConfig] = None, |
100 |
| - ) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]: |
101 |
| - """Returns an executorch program based on ModuleAdd, along with inputs.""" |
| 83 | + def get_inputs(self): |
| 84 | + return (torch.ones(2, 2),) |
102 | 85 |
|
103 |
| - # Trace the test module and create a serialized ExecuTorch program. |
104 |
| - inputs = eager_module.get_inputs() |
105 |
| - input_map = {} |
106 |
| - for method in eager_module.get_methods_to_export(): |
107 |
| - input_map[method] = inputs |
108 | 86 |
|
109 |
| - class WrapperModule(torch.nn.Module): |
110 |
| - def __init__(self, fn): |
111 |
| - super().__init__() |
112 |
| - self.fn = fn |
| 87 | +def create_program( |
| 88 | + eager_module: torch.nn.Module, |
| 89 | + et_config: Optional[ExecutorchBackendConfig] = None, |
| 90 | +) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]: |
| 91 | + """Returns an executorch program based on ModuleAdd, along with inputs.""" |
113 | 92 |
|
114 |
| - def forward(self, *args, **kwargs): |
115 |
| - return self.fn(*args, **kwargs) |
| 93 | + # Trace the test module and create a serialized ExecuTorch program. |
| 94 | + inputs = eager_module.get_inputs() |
| 95 | + input_map = {} |
| 96 | + for method in eager_module.get_methods_to_export(): |
| 97 | + input_map[method] = inputs |
116 | 98 |
|
117 |
| - exported_methods = {} |
118 |
| - # These cleanup passes are required to convert the `add` op to its out |
119 |
| - # variant, along with some other transformations. |
120 |
| - for method_name, method_input in input_map.items(): |
121 |
| - wrapped_mod = WrapperModule( # pyre-ignore[16] |
122 |
| - getattr(eager_module, method_name) |
123 |
| - ) |
124 |
| - exported_methods[method_name] = export(wrapped_mod, method_input) |
| 99 | + class WrapperModule(torch.nn.Module): |
| 100 | + def __init__(self, fn): |
| 101 | + super().__init__() |
| 102 | + self.fn = fn |
| 103 | + |
| 104 | + def forward(self, *args, **kwargs): |
| 105 | + return self.fn(*args, **kwargs) |
| 106 | + |
| 107 | + exported_methods = {} |
| 108 | + # These cleanup passes are required to convert the `add` op to its out |
| 109 | + # variant, along with some other transformations. |
| 110 | + for method_name, method_input in input_map.items(): |
| 111 | + wrapped_mod = WrapperModule( # pyre-ignore[16] |
| 112 | + getattr(eager_module, method_name) |
| 113 | + ) |
| 114 | + exported_methods[method_name] = export(wrapped_mod, method_input) |
| 115 | + |
| 116 | + exec_prog = to_edge(exported_methods).to_executorch(config=et_config) |
125 | 117 |
|
126 |
| - exec_prog = to_edge(exported_methods).to_executorch(config=et_config) |
| 118 | + # Create the ExecuTorch program from the graph. |
| 119 | + exec_prog.dump_executorch_program(verbose=True) |
| 120 | + return (exec_prog, inputs) |
127 | 121 |
|
128 |
| - # Create the ExecuTorch program from the graph. |
129 |
| - exec_prog.dump_executorch_program(verbose=True) |
130 |
| - return (exec_prog, inputs) |
| 122 | + |
| 123 | +def make_test( # noqa: C901 |
| 124 | + tester: unittest.TestCase, |
| 125 | + runtime: ModuleType, |
| 126 | +) -> Callable[[unittest.TestCase], None]: |
| 127 | + """ |
| 128 | + Returns a function that operates as a test case within a unittest.TestCase class. |
| 129 | +
|
| 130 | + Used to allow the test code for pybindings to be shared across different pybinding libs |
| 131 | + which will all have different load functions. In this case each individual test case is a |
| 132 | + subfunction of wrapper. |
| 133 | + """ |
| 134 | + load_fn: Callable = runtime._load_for_executorch_from_buffer |
| 135 | + |
| 136 | + def wrapper(tester: unittest.TestCase) -> None: |
131 | 137 |
|
132 | 138 | ######### TEST CASES #########
|
133 | 139 |
|
|
0 commit comments