Commit d9481d7

Eashan Garg authored and facebook-github-bot committed
Add option to enforce alignment constraint when planning memory
Summary: Add the ability to enforce a start-address alignment constraint in the Cadence memory planner.

Differential Revision: D68762973
1 parent 6cce750 commit d9481d7

File tree

3 files changed: 57 additions & 4 deletions


backends/cadence/aot/compiler.py

Lines changed: 2 additions & 0 deletions
@@ -257,6 +257,7 @@ def export_to_executorch_gen_etrecord(
     alloc_graph_output: bool = True,
     memory_config: Optional[MemoryConfig] = None,
     dump_graphs: bool = False,
+    mem_alignment: int = 0,
 ) -> ExecutorchProgramManager:
     cadence_passes = get_cadence_passes(opt_level)
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
@@ -283,6 +284,7 @@ def export_to_executorch_gen_etrecord(
         mem_algo=mem_algo,
         alloc_graph_input=alloc_graph_input,
         alloc_graph_output=alloc_graph_output,
+        mem_alignment=mem_alignment,
     )

     # Get executorch program after Cadence specific passes
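
For context, callers opt into the constraint through the new keyword on the export entry point above. Below is a minimal sketch of a call site, modeled on the test added in this commit; the toy module and the 16-byte value are illustrative, and mem_alignment=0 (the default) keeps the previous, unaligned behavior.

import torch

from executorch.backends.cadence.aot import compiler


class TinyAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.add(x, y)


inputs = (torch.randn(4, 17), torch.randn(4, 17))

# Request that every planned buffer start on a 16-byte boundary.
prog = compiler.export_to_executorch_gen_etrecord(
    TinyAdd(),
    inputs,
    opt_level=1,
    mem_algo=0,
    mem_alignment=16,
)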

backends/cadence/aot/memory_planning.py

Lines changed: 14 additions & 4 deletions
@@ -9,6 +9,7 @@
 import collections
 import itertools
 import logging
+import math
 import typing
 from functools import partial
 from typing import Iterable, List, Optional, Tuple
@@ -39,6 +40,12 @@ def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
     return memory_config.memory_sizes[exir_id - 1]


+def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
+    if alignment == 0:
+        return pre_aligned_offset
+    return int(math.ceil(pre_aligned_offset / alignment) * alignment)
+
+
 def collect_specs_from_graph_module(
     graph_module: torch.fx.GraphModule,
     alloc_graph_input: bool,
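
The new helper rounds an offset up to the next multiple of the requested alignment, with alignment == 0 meaning "no constraint". A quick sanity check of its behavior (the values are illustrative):

import math


def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
    # Same rounding as the hunk above: bump the offset to the next multiple of alignment.
    if alignment == 0:
        return pre_aligned_offset
    return int(math.ceil(pre_aligned_offset / alignment) * alignment)


assert get_aligned_offset(100, 32) == 128  # rounded up to the next 32-byte boundary
assert get_aligned_offset(128, 32) == 128  # already-aligned offsets are unchanged
assert get_aligned_offset(100, 0) == 100   # alignment of 0 disables the constraint

An all-integer form such as ((offset + alignment - 1) // alignment) * alignment would avoid floating-point rounding for very large offsets; the math.ceil version is what this commit uses.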
@@ -95,7 +102,7 @@ def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
         return None

     def memory_available(spec: TensorSpec) -> bool:
-        return spec.mem_offset + spec.allocated_memory <= get_size(
+        return get_aligned_offset(spec.mem_offset + spec.allocated_memory, alignment) <= get_size(
             memory_config, spec.mem_id
         )

@@ -116,7 +123,7 @@ def memory_available(spec: TensorSpec) -> bool:
             continue
         spec.mem_offset = 0
         while memory_available(spec) and (overlapped := overlap(spec)):
-            spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
+            spec.mem_offset = get_aligned_offset(overlapped.mem_offset + overlapped.allocated_memory, alignment)
         if memory_available(spec):
             allocated_buffers[spec.mem_id].append(spec)
             bufsizes[spec.mem_id] = max(
@@ -202,11 +209,11 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
             # calculation of gap incorrect. Moving it out will make the algorithm degenerate
             # to the naive one, reusing 0 tensor. The paper may have a typo here.
             prev_offset = max(
-                allocated_spec.mem_offset + allocated_spec.allocated_memory,
+                get_aligned_offset(allocated_spec.mem_offset + allocated_spec.allocated_memory, alignment),
                 prev_offset,
             )
         if spec.mem_offset is None:
-            if prev_offset + spec.allocated_memory > get_size(
+            if get_aligned_offset(prev_offset + spec.allocated_memory, alignment) > get_size(
                 memory_config, spec.mem_id
             ):
                 continue
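
To illustrate the placement effect: in both planners, the end of each overlapping, already-placed buffer is now rounded up to the alignment before it becomes the candidate start offset for the tensor being placed, so candidate offsets stay on the requested boundary. A small illustrative calculation (the buffer sizes are made up):

import math


def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
    # Same helper as above.
    if alignment == 0:
        return pre_aligned_offset
    return int(math.ceil(pre_aligned_offset / alignment) * alignment)


alignment = 32
# An already-placed buffer occupies bytes [0, 100); its aligned end is 128.
prev_offset = max(get_aligned_offset(0 + 100, alignment), 0)
assert prev_offset == 128             # without the constraint this would be 100
assert prev_offset % alignment == 0   # the next tensor starts on a 32-byte boundary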
@@ -423,6 +430,7 @@ def __init__(
                 ]
             ]
         ] = None,
+        mem_alignment: int = 0,
     ) -> None:
         self._init_mem_algos()

@@ -432,6 +440,7 @@ def __init__(
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
         self.additional_constraint_gen_passes = additional_constraint_gen_passes
+        self.mem_alignment = mem_alignment

     def _init_mem_algos(self) -> None:
         self.available_mem_algos = [
@@ -459,6 +468,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
             allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
+            alignment=self.mem_alignment,
         )
         mem_planning(graph_module)

backends/cadence/aot/tests/test_memory_passes.py

Lines changed: 41 additions & 0 deletions
@@ -10,9 +10,14 @@
 from executorch.backends.cadence.aot import compiler
 from executorch.backends.cadence.aot.memory_planning import find_peak_memory_usage
 from executorch.backends.cadence.aot.pass_utils import count_node
+from executorch.backends.cadence.aot.utils import (
+    get_default_memory_config,
+    MemoryConfig,
+)
 from executorch.exir import memory
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.tests.models import MultiLayerPerceptron
+from executorch.exir.memory_planning import collect_specs_from_nodes


 class TestMemPlanningPasses(unittest.TestCase):
@@ -762,3 +767,39 @@ def forward(self, x, y):
         )
         self.assertEqual(count_node(graph_module, memory.view), 1)
         self.verify_nop_memory_alloc(graph_module)
+
+
+    def test_start_alignment_constraints(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                add_0 = torch.add(x, y)
+                add_1 = torch.add(x, add_0)
+                add_2 = torch.add(add_0, add_1)
+                add_3 = torch.add(add_1, add_2)
+                return add_3
+
+        model = Model()
+        inputs = (torch.randn(4, 17), torch.randn(4, 17))
+        for mem_algo in range(0, 2):
+            graph_module = (
+                compiler.export_to_executorch_gen_etrecord(
+                    model,
+                    inputs,
+                    opt_level=1,
+                    mem_algo=mem_algo,
+                    alloc_graph_input=False,
+                    alloc_graph_output=False,
+                    mem_alignment=32,
+                )
+                .exported_program()
+                .graph_module
+            )
+            # Assert that all memory allocations are aligned to 32B start address
+            for spec in collect_specs_from_nodes(
+                graph_module.graph.nodes, False, False
+            ):
+                if spec and spec.mem_offset:
+                    self.assertEqual(spec.mem_offset % 32, 0)
