|
10 | 10 |
|
11 | 11 | import executorch.exir as exir
|
12 | 12 | import torch
|
13 |
| -import torch.fx as fx |
14 |
| -from executorch.exir import multi_method_program_to_executorch |
15 |
| -from executorch.exir.backend.backend_api import ( |
16 |
| - LoweredBackendModule, |
17 |
| - to_backend, |
18 |
| - to_backend_multiple, |
19 |
| -) |
| 13 | +from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend |
20 | 14 | from executorch.exir.backend.compile_spec_schema import CompileSpec
|
21 | 15 | from executorch.exir.backend.partitioner import (
|
22 | 16 | DelegationSpec,
|
@@ -1235,137 +1229,3 @@ def forward(self, x: List[torch.Tensor]):
|
1235 | 1229 |
|
1236 | 1230 | gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
|
1237 | 1231 | gm(*inputs)
|
1238 |
| - |
1239 |
| - def test_lower_multiple(self) -> None: |
1240 |
| - class MultipleMethodModule(torch.nn.Module): |
1241 |
| - def __init__(self) -> None: |
1242 |
| - super().__init__() |
1243 |
| - |
1244 |
| - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
1245 |
| - return x + y * y |
1246 |
| - |
1247 |
| - def method1(self, x: torch.Tensor) -> torch.Tensor: |
1248 |
| - return x + x - x |
1249 |
| - |
1250 |
| - def method2( |
1251 |
| - self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor |
1252 |
| - ) -> torch.Tensor: |
1253 |
| - return x + y - z |
1254 |
| - |
1255 |
| - module = MultipleMethodModule() |
1256 |
| - method_name_to_args = { |
1257 |
| - "forward": (torch.rand(2, 2), torch.rand(2, 2)), |
1258 |
| - "method1": (torch.rand(2, 2),), |
1259 |
| - "method2": (torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)), |
1260 |
| - } |
1261 |
| - |
1262 |
| - multi_method_prog = exir.capture_multiple( |
1263 |
| - module, method_name_to_args, exir.CaptureConfig() |
1264 |
| - ).to_edge() |
1265 |
| - |
1266 |
| - lowered_multi_method_prog = to_backend_multiple( |
1267 |
| - multi_method_prog, AddMulPartitionerDemo |
1268 |
| - ) |
1269 |
| - |
1270 |
| - for method_name, args in method_name_to_args.items(): |
1271 |
| - exported_prog = lowered_multi_method_prog.find_method(method_name) |
1272 |
| - self.assertIsNotNone(exported_prog) |
1273 |
| - exported_gm = exported_prog.exported_program.graph_module |
1274 |
| - self.assertIsInstance(exported_gm, fx.GraphModule) |
1275 |
| - |
1276 |
| - eager_method = getattr(module, method_name) |
1277 |
| - eager_results = eager_method(*args) |
1278 |
| - exported_results = exported_gm(*args) |
1279 |
| - self.assertTrue(torch.allclose(eager_results, exported_results[0])) |
1280 |
| - |
1281 |
| - add_nodes = [ |
1282 |
| - node |
1283 |
| - for node in exported_gm.graph.nodes |
1284 |
| - if node.op == "call_function" |
1285 |
| - and node.target == exir_ops.edge.aten.add.Tensor |
1286 |
| - ] |
1287 |
| - self.assertEqual(len(add_nodes), 0) |
1288 |
| - |
1289 |
| - lowered_submods = get_lowered_submodules(exported_gm) |
1290 |
| - self.assertEqual(len(lowered_submods), 1) |
1291 |
| - |
1292 |
| - _ = multi_method_program_to_executorch(lowered_multi_method_prog) |
1293 |
| - |
1294 |
| - def test_lower_multiple_selective(self) -> None: |
1295 |
| - class MultipleMethodModule(torch.nn.Module): |
1296 |
| - def __init__(self) -> None: |
1297 |
| - super().__init__() |
1298 |
| - |
1299 |
| - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
1300 |
| - return x + y * y |
1301 |
| - |
1302 |
| - def method1(self, x: torch.Tensor) -> torch.Tensor: |
1303 |
| - return x + x - x |
1304 |
| - |
1305 |
| - def method2( |
1306 |
| - self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor |
1307 |
| - ) -> torch.Tensor: |
1308 |
| - return x + y - z |
1309 |
| - |
1310 |
| - module = MultipleMethodModule() |
1311 |
| - method_name_to_args = { |
1312 |
| - "forward": (torch.rand(2, 2), torch.rand(2, 2)), |
1313 |
| - "method1": (torch.rand(2, 2),), |
1314 |
| - "method2": (torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)), |
1315 |
| - } |
1316 |
| - |
1317 |
| - multi_method_prog = exir.capture_multiple( |
1318 |
| - module, method_name_to_args, exir.CaptureConfig() |
1319 |
| - ).to_edge() |
1320 |
| - |
1321 |
| - method_name_to_partitioners = { |
1322 |
| - "forward": AddMulPartitionerDemo, |
1323 |
| - "method1": AddMulPartitionerDemo, |
1324 |
| - } |
1325 |
| - lowered_multi_method_prog = to_backend_multiple( |
1326 |
| - multi_method_prog, method_name_to_partitioners |
1327 |
| - ) |
1328 |
| - |
1329 |
| - for method_name, args in method_name_to_args.items(): |
1330 |
| - if method_name == "method2": |
1331 |
| - break |
1332 |
| - |
1333 |
| - exported_prog = lowered_multi_method_prog.find_method(method_name) |
1334 |
| - self.assertIsNotNone(exported_prog) |
1335 |
| - exported_gm = exported_prog.exported_program.graph_module |
1336 |
| - self.assertIsInstance(exported_gm, fx.GraphModule) |
1337 |
| - |
1338 |
| - eager_method = getattr(module, method_name) |
1339 |
| - eager_results = eager_method(*args) |
1340 |
| - exported_results = exported_gm(*args) |
1341 |
| - self.assertTrue(torch.allclose(eager_results, exported_results[0])) |
1342 |
| - |
1343 |
| - add_nodes = [ |
1344 |
| - node |
1345 |
| - for node in exported_gm.graph.nodes |
1346 |
| - if node.op == "call_function" |
1347 |
| - and node.target == exir_ops.edge.aten.add.Tensor |
1348 |
| - ] |
1349 |
| - self.assertEqual(len(add_nodes), 0) |
1350 |
| - |
1351 |
| - lowered_submods = get_lowered_submodules(exported_gm) |
1352 |
| - self.assertEqual(len(lowered_submods), 1) |
1353 |
| - |
1354 |
| - # Check that method2 had nothing lowered |
1355 |
| - method2_prog = lowered_multi_method_prog.find_method("method2") |
1356 |
| - self.assertIsNotNone(method2_prog) |
1357 |
| - method2_gm = method2_prog.exported_program.graph_module |
1358 |
| - self.assertIsInstance(method2_gm, fx.GraphModule) |
1359 |
| - add_nodes = [ |
1360 |
| - node |
1361 |
| - for node in method2_gm.graph.nodes |
1362 |
| - if node.op == "call_function" |
1363 |
| - and node.target == exir_ops.edge.aten.add.Tensor |
1364 |
| - ] |
1365 |
| - self.assertEqual(len(add_nodes), 1) |
1366 |
| - |
1367 |
| - lowered_submods = get_lowered_submodules(method2_gm) |
1368 |
| - self.assertEqual(len(lowered_submods), 0) |
1369 |
| - |
1370 |
| - # Check we can export to executorch properly |
1371 |
| - _ = multi_method_program_to_executorch(lowered_multi_method_prog) |
0 commit comments