12
12
13
13
import executorch .backends .cadence .aot .ops_registrations # noqa
14
14
import torch
15
+ from executorch .backends .cadence .aot .memory_planning import (
16
+ CadenceMemoryPlanning ,
17
+ print_memory_planning_info ,
18
+ )
15
19
from executorch .backends .cadence .aot .quantizer .fusion_pass import QuantFusion
16
20
from executorch .backends .cadence .aot .quantizer .quantizer import CadenceQuantizer
17
21
18
22
from executorch .backends .cadence .aot .replace_ops import ReplaceSafeSoftmaxWithSoftmax
19
- from executorch .backends .cadence .aot .utils import model_gm_has_SDPA , model_is_quantized
23
+ from executorch .backends .cadence .aot .utils import (
24
+ get_default_memory_config ,
25
+ MemoryConfig ,
26
+ model_gm_has_SDPA ,
27
+ model_is_quantized ,
28
+ )
20
29
from executorch .backends .transforms .decompose_sdpa import (
21
30
DecomposeScaledDotProductAttention ,
22
31
)
23
32
from executorch .devtools import generate_etrecord
24
33
from executorch .exir import (
25
34
EdgeCompileConfig ,
26
35
EdgeProgramManager ,
36
+ ExecutorchBackendConfig ,
27
37
ExecutorchProgramManager ,
28
38
to_edge ,
29
39
)
30
40
from executorch .exir .pass_base import PassResult
41
+ from executorch .exir .passes import ToOutVarPass
42
+ from executorch .exir .passes .sym_shape_eval_pass import HintBasedSymShapeEvalPass
31
43
from torch ._inductor .decomposition import remove_decompositions
32
44
from torch .ao .quantization .pt2e .export_utils import model_is_exported
33
45
from torch .ao .quantization .quantize_pt2e import convert_pt2e , prepare_pt2e
@@ -263,6 +275,10 @@ def export_to_executorch_gen_etrecord(
263
275
inputs : tuple [object , ...],
264
276
output_dir : Optional [str ] = None ,
265
277
opt_level : int = 1 ,
278
+ mem_algo : int = 0 ,
279
+ alloc_graph_input : bool = True ,
280
+ alloc_graph_output : bool = True ,
281
+ memory_config : Optional [MemoryConfig ] = None ,
266
282
dump_graphs : bool = False ,
267
283
) -> ExecutorchProgramManager :
268
284
cadence_passes = get_cadence_passes (opt_level )
@@ -281,8 +297,36 @@ def export_to_executorch_gen_etrecord(
281
297
cadence_prog_manager .exported_program ().graph_module ,
282
298
)
283
299
300
+ if memory_config is None :
301
+ memory_config = get_default_memory_config ()
302
+
303
+ memory_planning_pass = CadenceMemoryPlanning (
304
+ memory_config ,
305
+ opt_level = opt_level ,
306
+ mem_algo = mem_algo ,
307
+ alloc_graph_input = alloc_graph_input ,
308
+ alloc_graph_output = alloc_graph_output ,
309
+ )
310
+
284
311
# Get executorch program after Cadence specific passes
285
- exec_prog : ExecutorchProgramManager = cadence_prog_manager .to_executorch ()
312
+ exec_prog : ExecutorchProgramManager = cadence_prog_manager .to_executorch (
313
+ ExecutorchBackendConfig (
314
+ memory_planning_pass = memory_planning_pass ,
315
+ emit_stacktrace = False ,
316
+ to_out_var_pass = ToOutVarPass (),
317
+ extract_delegate_segments = False ,
318
+ sym_shape_eval_pass = HintBasedSymShapeEvalPass (),
319
+ ),
320
+ )
321
+
322
+ print_memory_planning_info (
323
+ exec_prog ,
324
+ memory_config ,
325
+ opt_level ,
326
+ alloc_graph_input ,
327
+ alloc_graph_output ,
328
+ )
329
+
286
330
if output_dir :
287
331
_gen_etrecord (edge_prog_manager , exec_prog , Path (output_dir ))
288
332
else :
0 commit comments