Skip to content

Commit 7fec943

Browse files
Varun Purifacebook-github-bot
authored andcommitted
Improvements to gen_sample_etrecord (#1073)
Summary: Pull Request resolved: #1073 Various improvements to the gen_sample_etrecord script. - Enable users to generate an etrecord from any of the sample models - Output files to /tmp instead of the working directory, provide a CLI option for this - Clean up imports and add type hints Differential Revision: D50579104 fbshipit-source-id: 40ee8297dba3bbe6f0bada5e819863b08e77862d
1 parent eead155 commit 7fec943

File tree

3 files changed

+52
-18
lines changed

3 files changed

+52
-18
lines changed

examples/sdk/scripts/etrecord.bin

15.5 KB
Binary file not shown.

examples/sdk/scripts/gen_sample_etrecord.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,39 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# Generate fixture files
8-
from pathlib import Path
9-
10-
import executorch.exir as exir
11-
from executorch.exir import ExecutorchBackendConfig
12-
13-
from executorch.exir.tests.models import BasicSinMax
8+
import argparse
9+
import copy
10+
import torch
11+
from executorch.exir import (
12+
EdgeCompileConfig,
13+
EdgeProgramManager,
14+
ExecutorchProgramManager,
15+
ExportedProgram,
16+
to_edge,
17+
)
1418
from executorch.sdk.etrecord import generate_etrecord
1519
from torch.export import export
1620

21+
from ...models import MODEL_NAME_TO_MODEL
22+
from ...models.model_factory import EagerModelFactory
23+
1724

18-
def get_module_path() -> Path:
19-
return Path(__file__).resolve().parents[0]
25+
DEFAULT_OUTPUT_PATH = "/tmp/etrecord.bin"
2026

2127

22-
def gen_etrecord():
23-
f = BasicSinMax()
24-
aten_dialect = export(
28+
def gen_etrecord(model: torch.nn.Module, inputs, output_path=None):
29+
f = model
30+
aten_dialect: ExportedProgram = export(
2531
f,
26-
f.get_random_inputs(),
32+
inputs,
2733
)
28-
edge_program = exir.to_edge(
29-
aten_dialect, compile_config=exir.EdgeCompileConfig(_check_ir_validity=False)
34+
edge_program: EdgeProgramManager = to_edge(
35+
aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=True)
3036
)
31-
et_program = edge_program.to_executorch(ExecutorchBackendConfig(passes=[]))
37+
edge_program_copy = copy.deepcopy(edge_program)
38+
et_program: ExecutorchProgramManager = edge_program_copy.to_executorch()
3239
generate_etrecord(
33-
str(get_module_path()) + "/etrecord.bin",
40+
(DEFAULT_OUTPUT_PATH if not output_path else output_path),
3441
edge_dialect_program=edge_program,
3542
executorch_program=et_program,
3643
export_modules={
@@ -40,4 +47,31 @@ def gen_etrecord():
4047

4148

4249
if __name__ == "__main__":
43-
gen_etrecord()
50+
parser = argparse.ArgumentParser()
51+
parser.add_argument(
52+
"-m",
53+
"--model_name",
54+
required=True,
55+
help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
56+
)
57+
58+
parser.add_argument(
59+
"-o",
60+
"--output_path",
61+
required=False,
62+
help=f"Provide an output path to save the generated etrecord. Defaults to {DEFAULT_OUTPUT_PATH}.",
63+
)
64+
65+
args = parser.parse_args()
66+
67+
if args.model_name not in MODEL_NAME_TO_MODEL:
68+
raise RuntimeError(
69+
f"Model {args.model_name} is not a valid name. "
70+
f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
71+
)
72+
73+
model, example_inputs = EagerModelFactory.create_model(
74+
*MODEL_NAME_TO_MODEL[args.model_name]
75+
)
76+
77+
gen_etrecord(model, example_inputs, args.output_path)

sdk/etrecord/_etrecord.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def generate_etrecord(
124124
Dict[
125125
str,
126126
Union[
127-
MultiMethodExirExportedProgram, ExirExportedProgram, EdgeProgramManager
127+
ExportedProgram, MultiMethodExirExportedProgram, ExirExportedProgram, EdgeProgramManager
128128
],
129129
]
130130
] = None,

0 commit comments

Comments
 (0)