Skip to content

Save the representative inputs into the ETRecord object #11244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions devtools/etrecord/_etrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from executorch.exir.serde.export_serialize import SerializedArtifact
from executorch.exir.serde.serialize import deserialize, serialize

ProgramInput = List[Value]
ProgramOutput = List[Value]

try:
Expand All @@ -49,6 +50,7 @@ class ETRecordReservedFileNames(StrEnum):
DEBUG_HANDLE_MAP_NAME = "debug_handle_map"
DELEGATE_MAP_NAME = "delegate_map"
REFERENCE_OUTPUTS = "reference_outputs"
REPRESENTATIVE_INPUTS = "representative_inputs"


@dataclass
Expand All @@ -60,6 +62,7 @@ class ETRecord:
Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]
] = None
_reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None
_representative_inputs: Optional[List[ProgramOutput]] = None


def _handle_exported_program(
Expand Down Expand Up @@ -157,6 +160,24 @@ def _get_reference_outputs(
return reference_outputs


def _get_representative_inputs(
    bundled_program: BundledProgram,
) -> List[ProgramInput]:
    """
    Return the inputs of the first test case bundled for the ``forward`` method.

    The method test suites of ``bundled_program`` are scanned in order and the
    first suite whose ``method_name`` equals ``"forward"`` is used. Only the
    ``inputs`` of that suite's first test case are returned; other test cases
    and other methods are ignored.

    Args:
        bundled_program: Object exposing ``method_test_suites``; each suite
            provides ``method_name`` and ``test_cases``.

    Returns:
        The ``inputs`` attribute of the first ``forward`` test case.

    Raises:
        ValueError: If no suite is named ``"forward"``, or if the first such
            suite has no test cases.
    """
    # Only the first suite named "forward" is considered, even if later
    # suites share that name.
    forward_suite = next(
        (
            suite
            for suite in bundled_program.method_test_suites
            if suite.method_name == "forward"
        ),
        None,
    )
    if forward_suite is None:
        raise ValueError("No 'forward' method found in the bundled program.")

    forward_cases = forward_suite.test_cases
    if not forward_cases:
        raise ValueError(
            "The 'forward' method is defined, but no corresponding input test cases are provided."
        )

    # The first example input of the forward method is the representative one.
    return forward_cases[0].inputs


def generate_etrecord(
et_record: Union[str, os.PathLike, BinaryIO, IO[bytes]],
edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram],
Expand Down Expand Up @@ -244,6 +265,13 @@ def generate_etrecord(
# @lint-ignore PYTHONPICKLEISBAD
pickle.dumps(reference_outputs),
)

representative_inputs = _get_representative_inputs(executorch_program)
etrecord_zip.writestr(
ETRecordReservedFileNames.REPRESENTATIVE_INPUTS,
# @lint-ignore PYTHONPICKLEISBAD
pickle.dumps(representative_inputs),
)
executorch_program = executorch_program.executorch_program

etrecord_zip.writestr(
Expand Down Expand Up @@ -290,6 +318,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
delegate_map = None
edge_dialect_program = None
reference_outputs = None
representative_inputs = None

serialized_exported_program_files = set()
serialized_state_dict_files = set()
Expand Down Expand Up @@ -321,6 +350,11 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
reference_outputs = pickle.loads(
etrecord_zip.read(ETRecordReservedFileNames.REFERENCE_OUTPUTS)
)
elif entry == ETRecordReservedFileNames.REPRESENTATIVE_INPUTS:
# @lint-ignore PYTHONPICKLEISBAD
representative_inputs = pickle.loads(
etrecord_zip.read(ETRecordReservedFileNames.REPRESENTATIVE_INPUTS)
)
else:
if entry.endswith("state_dict"):
serialized_state_dict_files.add(entry)
Expand Down Expand Up @@ -352,4 +386,5 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
_debug_handle_map=debug_handle_map,
_delegate_map=delegate_map,
_reference_outputs=reference_outputs,
_representative_inputs=representative_inputs,
)
19 changes: 15 additions & 4 deletions devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from executorch.devtools.etrecord import generate_etrecord, parse_etrecord
from executorch.devtools.etrecord._etrecord import (
_get_reference_outputs,
_get_representative_inputs,
ETRecordReservedFileNames,
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
Expand Down Expand Up @@ -135,15 +136,25 @@ def test_etrecord_generation_with_bundled_program(self):
)
etrecord = parse_etrecord(tmpdirname + "/etrecord.bin")

expected = etrecord._reference_outputs
actual = _get_reference_outputs(bundled_program)
expected_inputs = etrecord._representative_inputs
actual_inputs = _get_representative_inputs(bundled_program)
# assertEqual() gives "RuntimeError: Boolean value of Tensor with more than one value is ambiguous" when comparing tensors,
# so we use torch.equal() to compare the tensors one by one.
for expected, actual in zip(expected_inputs, actual_inputs):
self.assertTrue(torch.equal(expected[0], actual[0]))
self.assertTrue(torch.equal(expected[1], actual[1]))

expected_outputs = etrecord._reference_outputs
actual_outputs = _get_reference_outputs(bundled_program)
self.assertTrue(
torch.equal(expected["forward"][0][0], actual["forward"][0][0])
torch.equal(
expected_outputs["forward"][0][0], actual_outputs["forward"][0][0]
)
)
self.assertTrue(
torch.equal(expected["forward"][1][0], actual["forward"][1][0])
torch.equal(
expected_outputs["forward"][1][0], actual_outputs["forward"][1][0]
)
)

def test_etrecord_generation_with_manager(self):
Expand Down
Loading