|
6 | 6 |
|
7 | 7 | # pyre-unsafe
|
8 | 8 |
|
| 9 | +import copy |
9 | 10 | import random
|
10 | 11 | import statistics
|
11 | 12 | import tempfile
|
|
16 | 17 |
|
17 | 18 | from unittest.mock import patch
|
18 | 19 |
|
| 20 | +import torch |
| 21 | +import torch.fx |
| 22 | + |
19 | 23 | from executorch.devtools import generate_etrecord, parse_etrecord
|
20 | 24 | from executorch.devtools.debug_format.et_schema import OperatorNode
|
21 | 25 | from executorch.devtools.etdump.schema_flatcc import ProfileEvent
|
| 26 | +from executorch.devtools.etrecord._etrecord import ETRecord |
22 | 27 | from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord
|
23 | 28 |
|
24 | 29 | from executorch.devtools.inspector import (
|
|
36 | 41 | ProfileEventSignature,
|
37 | 42 | TimeScale,
|
38 | 43 | )
|
39 |
| - |
40 |
| -from executorch.exir import ExportedProgram |
| 44 | +from executorch.devtools.inspector.tests.inspector_test_utils import ( |
| 45 | + check_if_final_outputs_match, |
| 46 | + model_registry, |
| 47 | +) |
| 48 | +from executorch.exir import ( |
| 49 | + EdgeCompileConfig, |
| 50 | + EdgeProgramManager, |
| 51 | + ExecutorchProgramManager, |
| 52 | + to_edge, |
| 53 | +) |
| 54 | +from torch.export import export, ExportedProgram |
41 | 55 |
|
42 | 56 |
|
43 | 57 | OP_TYPE = "aten::add"
|
@@ -452,6 +466,75 @@ def test_populate_debugging_related_fields_passes_for_consistent_events(self):
|
452 | 466 | events=events,
|
453 | 467 | )
|
454 | 468 |
|
| 469 | + def test_no_capture_when_representative_inputs_are_none(self): |
| 470 | + # Create a context manager to patch functions called by Inspector.__init__ |
| 471 | + with patch.object( |
| 472 | + _inspector, "parse_etrecord", return_value=None |
| 473 | + ), patch.object( |
| 474 | + _inspector, "gen_etdump_object", return_value=None |
| 475 | + ), patch.object( |
| 476 | + EventBlock, "_gen_from_etdump" |
| 477 | + ), patch.object( |
| 478 | + _inspector, "gen_graphs_from_etrecord" |
| 479 | + ): |
| 480 | + # Call the constructor of Inspector |
| 481 | + inspector_instance = Inspector( |
| 482 | + etdump_path=ETDUMP_PATH, |
| 483 | + etrecord=ETRECORD_PATH, |
| 484 | + ) |
| 485 | + self.assertIsNone(inspector_instance._aot_intermediate_outputs) |
| 486 | + |
| 487 | + def test_consume_etrecord_populates_correct_aot_intermediate_outputs(self): |
| 488 | + with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file: |
| 489 | + etrecord_path = tmp_file.name |
| 490 | + mod = model_registry["ConvLinearModel"]() |
| 491 | + input_tensor = torch.tensor( |
| 492 | + [[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True |
| 493 | + ) |
| 494 | + aten_model: ExportedProgram = export(mod, (input_tensor,), strict=True) |
| 495 | + edge_program_manager: EdgeProgramManager = to_edge( |
| 496 | + aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True) |
| 497 | + ) |
| 498 | + edge_program_manager_copy = copy.deepcopy(edge_program_manager) |
| 499 | + et_program_manager: ExecutorchProgramManager = ( |
| 500 | + edge_program_manager.to_executorch() |
| 501 | + ) |
| 502 | + # Generate ETRecord |
| 503 | + generate_etrecord( |
| 504 | + etrecord_path, edge_program_manager_copy, et_program_manager |
| 505 | + ) |
| 506 | + original_consume_etrecord = Inspector._consume_etrecord |
| 507 | + with patch.object( |
| 508 | + Inspector, "_consume_etrecord", return_value=None |
| 509 | + ), patch.object( |
| 510 | + _inspector, "gen_etdump_object", return_value=None |
| 511 | + ), patch.object( |
| 512 | + EventBlock, "_gen_from_etdump" |
| 513 | + ), patch.object( |
| 514 | + _inspector, "gen_graphs_from_etrecord" |
| 515 | + ): |
| 516 | + # Call the constructor of Inspector |
| 517 | + inspector_instance = Inspector( |
| 518 | + etdump_path=ETDUMP_PATH, |
| 519 | + etrecord=etrecord_path, |
| 520 | + ) |
| 521 | + etrecord = ETRecord( |
| 522 | + edge_dialect_program=inspector_instance._etrecord.edge_dialect_program, |
| 523 | + graph_map=inspector_instance._etrecord.graph_map, |
| 524 | + _debug_handle_map=inspector_instance._etrecord._debug_handle_map, |
| 525 | + _delegate_map=inspector_instance._etrecord._delegate_map, |
| 526 | + _reference_outputs=inspector_instance._etrecord._reference_outputs, |
| 527 | + _representative_inputs=aten_model.example_inputs[0], |
| 528 | + ) |
| 529 | + inspector_instance._etrecord = etrecord |
| 530 | + Inspector._consume_etrecord = original_consume_etrecord |
| 531 | + inspector_instance._consume_etrecord() |
| 532 | + self.assertTrue( |
| 533 | + check_if_final_outputs_match( |
| 534 | + "ConvLinearModel", inspector_instance._aot_intermediate_outputs |
| 535 | + ) |
| 536 | + ) |
| 537 | + |
455 | 538 | def _gen_random_float_list(self) -> List[float]:
|
456 | 539 | return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]
|
457 | 540 |
|
|
0 commit comments