Skip to content

Update CadenceMemoryPlanning to support per-memory alignment constraint #8689

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ def export_to_executorch_gen_etrecord(
alloc_graph_output: bool = True,
memory_config: Optional[MemoryConfig] = None,
dump_graphs: bool = False,
mem_alignment: int = 1,
) -> ExecutorchProgramManager:
cadence_passes = get_cadence_passes(opt_level)
edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
Expand All @@ -291,7 +290,6 @@ def export_to_executorch_gen_etrecord(
mem_algo=mem_algo,
alloc_graph_input=alloc_graph_input,
alloc_graph_output=alloc_graph_output,
mem_alignment=mem_alignment,
)

# Get executorch program after Cadence specific passes
Expand Down
30 changes: 21 additions & 9 deletions backends/cadence/aot/memory_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
return memory_config.memory_sizes[exir_id - 1]


def get_alignment(memory_config: MemoryConfig, exir_id: int) -> int:
    """Return the alignment constraint (in bytes) of the memory region `exir_id`.

    EXIR's spec.mem_id is indexed from 1..N, so we shift down by one to
    index into the 0-based alignments list.
    """
    alignments = memory_config.memory_alignments
    # `__post_init__` on MemoryConfig fills this in, so it should never be None here.
    assert alignments is not None
    return alignments[exir_id - 1]


def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
    """Round `pre_aligned_offset` up to the nearest multiple of `alignment`.

    Uses exact integer ceiling division instead of the float-based
    `math.ceil(pre_aligned_offset / alignment) * alignment`: true division
    goes through a float, which loses precision once the offset exceeds
    2**53 and can silently yield a wrong (even misaligned) result.
    """
    return (pre_aligned_offset + alignment - 1) // alignment * alignment

Expand Down Expand Up @@ -84,6 +90,10 @@ def position_based_greedy_with_hierarchy(
]
] = None,
) -> List[int]:
# We do not use the `alignment` parameter and instead use the per-memory alignment
# constraints from `memory_config`.
del alignment

num_memories = get_num_memories(memory_config)
bufsizes = [0] * num_memories
allocated_buffers: List[List[TensorSpec]] = [[] for _ in range(num_memories)]
Expand All @@ -103,7 +113,8 @@ def overlap(spec: TensorSpec) -> Optional[TensorSpec]:

def memory_available(spec: TensorSpec) -> bool:
return get_aligned_offset(
spec.mem_offset + spec.allocated_memory, alignment
spec.mem_offset + spec.allocated_memory,
get_alignment(memory_config, spec.mem_id),
) <= get_size(memory_config, spec.mem_id)

# Iterate over all the specs in sorted order
Expand All @@ -124,7 +135,8 @@ def memory_available(spec: TensorSpec) -> bool:
spec.mem_offset = 0
while memory_available(spec) and (overlapped := overlap(spec)):
spec.mem_offset = get_aligned_offset(
overlapped.mem_offset + overlapped.allocated_memory, alignment
overlapped.mem_offset + overlapped.allocated_memory,
get_alignment(memory_config, spec.mem_id),
)
if memory_available(spec):
allocated_buffers[spec.mem_id].append(spec)
Expand Down Expand Up @@ -172,6 +184,10 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
]
] = None,
) -> List[int]:
# We do not use the `alignment` parameter and instead use the per-memory alignment
# constraints from `memory_config`.
del alignment

num_memories = get_num_memories(memory_config)
bufsizes = [0] * num_memories
allocated_buffers = [[] for _ in range(num_memories)]
Expand Down Expand Up @@ -213,13 +229,14 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
prev_offset = max(
get_aligned_offset(
allocated_spec.mem_offset + allocated_spec.allocated_memory,
alignment,
get_alignment(memory_config, spec.mem_id),
),
prev_offset,
)
if spec.mem_offset is None:
if get_aligned_offset(
prev_offset + spec.allocated_memory, alignment
prev_offset + spec.allocated_memory,
get_alignment(memory_config, spec.mem_id),
) > get_size(memory_config, spec.mem_id):
continue
else:
Expand Down Expand Up @@ -439,7 +456,6 @@ def __init__(
]
]
] = None,
mem_alignment: int = 1,
) -> None:
self._init_mem_algos()

Expand All @@ -450,9 +466,6 @@ def __init__(
self.alloc_graph_output = alloc_graph_output
self.additional_constraint_gen_passes = additional_constraint_gen_passes

assert mem_alignment > 0, "mem_alignment must be positive"
self.mem_alignment = mem_alignment

def _init_mem_algos(self) -> None:
self.available_mem_algos = [
position_based_greedy_with_hierarchy,
Expand Down Expand Up @@ -489,7 +502,6 @@ def run(
allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
alloc_graph_input=self.alloc_graph_input,
alloc_graph_output=self.alloc_graph_output,
alignment=self.mem_alignment,
)
mem_planning.run(graph_module, graph_signature)

Expand Down
5 changes: 4 additions & 1 deletion backends/cadence/aot/tests/test_memory_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.memory_planning import find_peak_memory_usage
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.utils import MemoryConfig
from executorch.exir import memory
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.memory_planning import collect_specs_from_nodes
Expand Down Expand Up @@ -792,7 +793,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
mem_algo=mem_algo,
alloc_graph_input=False,
alloc_graph_output=False,
mem_alignment=37,
memory_config=MemoryConfig(
memory_sizes=[0x1000000000], memory_alignments=[37]
),
)
.exported_program()
.graph_module
Expand Down
6 changes: 6 additions & 0 deletions backends/cadence/aot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,19 @@ def save_bpte_program(
@dataclass
class MemoryConfig:
memory_sizes: List[int]
# Alignment constraint for each memory region in bytes.
memory_alignments: Optional[List[int]] = None

# Optional fields for logs
memory_names: Optional[List[str]] = None
base_addrs: Optional[List[int]] = None
memory_xml_path: Optional[str] = None
MemorySpace: Optional[enum.Enum] = None

def __post_init__(self) -> None:
if self.memory_alignments is None:
self.memory_alignments = [1] * len(self.memory_sizes)

# get num memories indexed from 1..N, compatible with EXIR's spec.mem_id
def get_num_memories(self) -> int:
return len(self.memory_sizes) + 1
Expand Down
Loading