5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
# Generate fixture files
8
- from pathlib import Path
8
+ import argparse
9
+ import copy
10
+ from typing import Any
9
11
10
- import executorch .exir as exir
11
- from executorch .exir import ExecutorchBackendConfig
12
-
13
- from executorch .exir .tests .models import BasicSinMax
12
+ import torch
13
+ from executorch .exir import (
14
+ EdgeCompileConfig ,
15
+ EdgeProgramManager ,
16
+ ExecutorchProgramManager ,
17
+ ExportedProgram ,
18
+ to_edge ,
19
+ )
14
20
from executorch .sdk .etrecord import generate_etrecord
15
21
from torch .export import export
16
22
23
+ from ...models import MODEL_NAME_TO_MODEL
24
+ from ...models .model_factory import EagerModelFactory
25
+
17
26
18
- def get_module_path () -> Path :
19
- return Path (__file__ ).resolve ().parents [0 ]
27
+ DEFAULT_OUTPUT_PATH = "/tmp/etrecord.bin"
20
28
21
29
22
- def gen_etrecord ():
23
- f = BasicSinMax ()
24
- aten_dialect = export (
30
+ def gen_etrecord (model : torch . nn . Module , inputs : Any , output_path = None ):
31
+ f = model
32
+ aten_dialect : ExportedProgram = export (
25
33
f ,
26
- f . get_random_inputs () ,
34
+ inputs ,
27
35
)
28
- edge_program = exir . to_edge (
29
- aten_dialect , compile_config = exir . EdgeCompileConfig (_check_ir_validity = False )
36
+ edge_program : EdgeProgramManager = to_edge (
37
+ aten_dialect , compile_config = EdgeCompileConfig (_check_ir_validity = True )
30
38
)
31
- et_program = edge_program .to_executorch (ExecutorchBackendConfig (passes = []))
39
+ edge_program_copy = copy .deepcopy (edge_program )
40
+ et_program : ExecutorchProgramManager = edge_program_copy .to_executorch ()
32
41
generate_etrecord (
33
- str ( get_module_path ()) + "/etrecord.bin" ,
42
+ ( DEFAULT_OUTPUT_PATH if not output_path else output_path ) ,
34
43
edge_dialect_program = edge_program ,
35
44
executorch_program = et_program ,
36
45
export_modules = {
@@ -40,4 +49,31 @@ def gen_etrecord():
40
49
41
50
42
51
if __name__ == "__main__" :
43
- gen_etrecord ()
52
+ parser = argparse .ArgumentParser ()
53
+ parser .add_argument (
54
+ "-m" ,
55
+ "--model_name" ,
56
+ required = True ,
57
+ help = f"provide a model name. Valid ones: { list (MODEL_NAME_TO_MODEL .keys ())} " ,
58
+ )
59
+
60
+ parser .add_argument (
61
+ "-o" ,
62
+ "--output_path" ,
63
+ required = False ,
64
+ help = f"Provide an output path to save the generated etrecord. Defaults to { DEFAULT_OUTPUT_PATH } ." ,
65
+ )
66
+
67
+ args = parser .parse_args ()
68
+
69
+ if args .model_name not in MODEL_NAME_TO_MODEL :
70
+ raise RuntimeError (
71
+ f"Model { args .model_name } is not a valid name. "
72
+ f"Available models are { list (MODEL_NAME_TO_MODEL .keys ())} ."
73
+ )
74
+
75
+ model , example_inputs = EagerModelFactory .create_model (
76
+ * MODEL_NAME_TO_MODEL [args .model_name ]
77
+ )
78
+
79
+ gen_etrecord (model , example_inputs , args .output_path )
0 commit comments