
Commit f370e78

Added Cadence memory planning example
Differential Revision: D65240495
Pull Request resolved: #7418

1 parent: 82763a9

2 files changed (+47 -2 lines)
backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ python_library(
         ":utils",
         ":ops_registrations",
         ":replace_ops",
+        ":memory_planning",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot/quantizer:fusion_pass",
         "//executorch/backends/cadence/aot/quantizer:quantizer",

backends/cadence/aot/compiler.py

Lines changed: 46 additions & 2 deletions
@@ -12,22 +12,34 @@
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
+from executorch.backends.cadence.aot.memory_planning import (
+    CadenceMemoryPlanning,
+    print_memory_planning_info,
+)
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
 
 from executorch.backends.cadence.aot.replace_ops import ReplaceSafeSoftmaxWithSoftmax
-from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
+from executorch.backends.cadence.aot.utils import (
+    get_default_memory_config,
+    MemoryConfig,
+    model_gm_has_SDPA,
+    model_is_quantized,
+)
 from executorch.backends.transforms.decompose_sdpa import (
     DecomposeScaledDotProductAttention,
 )
 from executorch.devtools import generate_etrecord
 from executorch.exir import (
     EdgeCompileConfig,
     EdgeProgramManager,
+    ExecutorchBackendConfig,
     ExecutorchProgramManager,
     to_edge,
 )
 from executorch.exir.pass_base import PassResult
+from executorch.exir.passes import ToOutVarPass
+from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
@@ -263,6 +275,10 @@ def export_to_executorch_gen_etrecord(
     inputs: tuple[object, ...],
     output_dir: Optional[str] = None,
     opt_level: int = 1,
+    mem_algo: int = 0,
+    alloc_graph_input: bool = True,
+    alloc_graph_output: bool = True,
+    memory_config: Optional[MemoryConfig] = None,
     dump_graphs: bool = False,
 ) -> ExecutorchProgramManager:
     cadence_passes = get_cadence_passes(opt_level)
@@ -281,8 +297,36 @@ def export_to_executorch_gen_etrecord(
         cadence_prog_manager.exported_program().graph_module,
     )
 
+    if memory_config is None:
+        memory_config = get_default_memory_config()
+
+    memory_planning_pass = CadenceMemoryPlanning(
+        memory_config,
+        opt_level=opt_level,
+        mem_algo=mem_algo,
+        alloc_graph_input=alloc_graph_input,
+        alloc_graph_output=alloc_graph_output,
+    )
+
     # Get executorch program after Cadence specific passes
-    exec_prog: ExecutorchProgramManager = cadence_prog_manager.to_executorch()
+    exec_prog: ExecutorchProgramManager = cadence_prog_manager.to_executorch(
+        ExecutorchBackendConfig(
+            memory_planning_pass=memory_planning_pass,
+            emit_stacktrace=False,
+            to_out_var_pass=ToOutVarPass(),
+            extract_delegate_segments=False,
+            sym_shape_eval_pass=HintBasedSymShapeEvalPass(),
+        ),
+    )
+
+    print_memory_planning_info(
+        exec_prog,
+        memory_config,
+        opt_level,
+        alloc_graph_input,
+        alloc_graph_output,
+    )
+
     if output_dir:
         _gen_etrecord(edge_prog_manager, exec_prog, Path(output_dir))
     else:
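For context, here is a minimal caller-side sketch of how the new memory-planning knobs might be exercised. Only the parameter names (mem_algo, alloc_graph_input, alloc_graph_output, memory_config) and the module path come from this diff; the toy model, the assumption that the first positional argument is an nn.Module, and the expectation that the model is normally quantized/converted before this call are illustrative guesses, not part of the commit.

import torch

from executorch.backends.cadence.aot.compiler import export_to_executorch_gen_etrecord


class ToyModel(torch.nn.Module):
    """Stand-in model, for illustration only."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x + 1.0)


# In the real Cadence flow this would typically be a quantized/converted module.
model = ToyModel()
inputs = (torch.randn(1, 16),)

# memory_config=None makes the compiler fall back to get_default_memory_config();
# mem_algo picks the planning algorithm, and the alloc_graph_* flags control
# whether graph inputs/outputs are included in the plan (per this diff).
exec_prog = export_to_executorch_gen_etrecord(
    model,  # assumed first positional argument; not shown in the hunks above
    inputs,
    output_dir=None,
    opt_level=1,
    mem_algo=0,
    alloc_graph_input=True,
    alloc_graph_output=True,
    memory_config=None,
)

Because print_memory_planning_info is now called unconditionally inside export_to_executorch_gen_etrecord, a call like the one above would also print the memory-planning summary for the chosen configuration.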
