Skip to content

Commit d010e62

Browse files
Varun Purifacebook-github-bot
authored andcommitted
Improvements to gen_sample_etrecord
Summary: Various improvements to the gen_sample_etrecord script. - Enable users to generate an etrecord from any of the sample models - Output files to /tmp instead of the working directory, provide a CLI option for this - Clean up imports and add type hints Differential Revision: D50579104 fbshipit-source-id: d40366e4381e6cc42cd566d03f84224c551b7370
1 parent 1133d45 commit d010e62

File tree

2 files changed

+49
-16
lines changed

2 files changed

+49
-16
lines changed

examples/sdk/scripts/etrecord.bin

15.5 KB
Binary file not shown.

examples/sdk/scripts/gen_sample_etrecord.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,38 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# Generate fixture files
8-
from pathlib import Path
8+
import argparse
99

10-
import executorch.exir as exir
11-
from executorch.exir import ExecutorchBackendConfig
12-
13-
from executorch.exir.tests.models import BasicSinMax
10+
import torch
11+
from executorch.exir import (
12+
EdgeCompileConfig,
13+
EdgeProgramManager,
14+
ExecutorchProgramManager,
15+
ExportedProgram,
16+
to_edge,
17+
)
1418
from executorch.sdk.etrecord import generate_etrecord
1519
from torch.export import export
1620

21+
from ...models import MODEL_NAME_TO_MODEL
22+
from ...models.model_factory import EagerModelFactory
23+
1724

18-
def get_module_path() -> Path:
19-
return Path(__file__).resolve().parents[0]
25+
DEFAULT_OUTPUT_PATH = "/tmp/etrecord.bin"
2026

2127

22-
def gen_etrecord():
23-
f = BasicSinMax()
24-
aten_dialect = export(
28+
def gen_etrecord(model: torch.nn.Module, inputs, output_path=None):
29+
f = model
30+
aten_dialect: ExportedProgram = export(
2531
f,
26-
f.get_random_inputs(),
32+
inputs,
2733
)
28-
edge_program = exir.to_edge(
29-
aten_dialect, compile_config=exir.EdgeCompileConfig(_check_ir_validity=False)
34+
edge_program: EdgeProgramManager = to_edge(
35+
aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=True)
3036
)
31-
et_program = edge_program.to_executorch(ExecutorchBackendConfig(passes=[]))
37+
et_program: ExecutorchProgramManager = edge_program.to_executorch()
3238
generate_etrecord(
33-
str(get_module_path()) + "/etrecord.bin",
39+
(DEFAULT_OUTPUT_PATH if not output_path else output_path),
3440
edge_dialect_program=edge_program,
3541
executorch_program=et_program,
3642
export_modules={
@@ -40,4 +46,31 @@ def gen_etrecord():
4046

4147

4248
if __name__ == "__main__":
43-
gen_etrecord()
49+
parser = argparse.ArgumentParser()
50+
parser.add_argument(
51+
"-m",
52+
"--model_name",
53+
required=True,
54+
help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
55+
)
56+
57+
parser.add_argument(
58+
"-o",
59+
"--output_path",
60+
required=False,
61+
help=f"Provide an output path to save the generated etrecord. Defaults to {DEFAULT_OUTPUT_PATH}.",
62+
)
63+
64+
args = parser.parse_args()
65+
66+
if args.model_name not in MODEL_NAME_TO_MODEL:
67+
raise RuntimeError(
68+
f"Model {args.model_name} is not a valid name. "
69+
f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
70+
)
71+
72+
model, example_inputs = EagerModelFactory.create_model(
73+
*MODEL_NAME_TO_MODEL[args.model_name]
74+
)
75+
76+
gen_etrecord(model, example_inputs, args.output_path)

0 commit comments

Comments
 (0)