Skip to content

Commit 7814dd7

Browse files
Varun Purifacebook-github-bot
authored andcommitted
Improvements to gen_sample_etrecord (#1073)
Summary: Pull Request resolved: #1073 Various improvements to the gen_sample_etrecord script. - Enable users to generate an etrecord from any of the sample models - Output files to /tmp instead of the working directory, provide a CLI option for this - Clean up imports and add type hints Reviewed By: Jack-Khuu Differential Revision: D50579104 fbshipit-source-id: 7543f274b0ce9ee96a63fc4f8e8e9bdddebabfd5
1 parent 1f7ad33 commit 7814dd7

File tree

3 files changed

+56
-17
lines changed

3 files changed

+56
-17
lines changed

examples/sdk/scripts/etrecord.bin

15.5 KB
Binary file not shown.

examples/sdk/scripts/gen_sample_etrecord.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,41 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# Generate fixture files
8-
from pathlib import Path
8+
import argparse
9+
import copy
10+
from typing import Any
911

10-
import executorch.exir as exir
11-
from executorch.exir import ExecutorchBackendConfig
12-
13-
from executorch.exir.tests.models import BasicSinMax
12+
import torch
13+
from executorch.exir import (
14+
EdgeCompileConfig,
15+
EdgeProgramManager,
16+
ExecutorchProgramManager,
17+
ExportedProgram,
18+
to_edge,
19+
)
1420
from executorch.sdk.etrecord import generate_etrecord
1521
from torch.export import export
1622

23+
from ...models import MODEL_NAME_TO_MODEL
24+
from ...models.model_factory import EagerModelFactory
25+
1726

18-
def get_module_path() -> Path:
19-
return Path(__file__).resolve().parents[0]
27+
DEFAULT_OUTPUT_PATH = "/tmp/etrecord.bin"
2028

2129

22-
def gen_etrecord():
23-
f = BasicSinMax()
24-
aten_dialect = export(
30+
def gen_etrecord(model: torch.nn.Module, inputs: Any, output_path=None):
31+
f = model
32+
aten_dialect: ExportedProgram = export(
2533
f,
26-
f.get_random_inputs(),
34+
inputs,
2735
)
28-
edge_program = exir.to_edge(
29-
aten_dialect, compile_config=exir.EdgeCompileConfig(_check_ir_validity=False)
36+
edge_program: EdgeProgramManager = to_edge(
37+
aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=True)
3038
)
31-
et_program = edge_program.to_executorch(ExecutorchBackendConfig(passes=[]))
39+
edge_program_copy = copy.deepcopy(edge_program)
40+
et_program: ExecutorchProgramManager = edge_program_copy.to_executorch()
3241
generate_etrecord(
33-
str(get_module_path()) + "/etrecord.bin",
42+
(DEFAULT_OUTPUT_PATH if not output_path else output_path),
3443
edge_dialect_program=edge_program,
3544
executorch_program=et_program,
3645
export_modules={
@@ -40,4 +49,31 @@ def gen_etrecord():
4049

4150

4251
if __name__ == "__main__":
43-
gen_etrecord()
52+
parser = argparse.ArgumentParser()
53+
parser.add_argument(
54+
"-m",
55+
"--model_name",
56+
required=True,
57+
help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
58+
)
59+
60+
parser.add_argument(
61+
"-o",
62+
"--output_path",
63+
required=False,
64+
help=f"Provide an output path to save the generated etrecord. Defaults to {DEFAULT_OUTPUT_PATH}.",
65+
)
66+
67+
args = parser.parse_args()
68+
69+
if args.model_name not in MODEL_NAME_TO_MODEL:
70+
raise RuntimeError(
71+
f"Model {args.model_name} is not a valid name. "
72+
f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
73+
)
74+
75+
model, example_inputs = EagerModelFactory.create_model(
76+
*MODEL_NAME_TO_MODEL[args.model_name]
77+
)
78+
79+
gen_etrecord(model, example_inputs, args.output_path)

sdk/etrecord/_etrecord.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,10 @@ def generate_etrecord(
124124
Dict[
125125
str,
126126
Union[
127-
MultiMethodExirExportedProgram, ExirExportedProgram, EdgeProgramManager
127+
ExportedProgram,
128+
MultiMethodExirExportedProgram,
129+
ExirExportedProgram,
130+
EdgeProgramManager,
128131
],
129132
]
130133
] = None,

0 commit comments

Comments
 (0)