Port memory planning to Cadence #6716

Merged 1 commit on Nov 18, 2024
365 changes: 365 additions & 0 deletions backends/cadence/aot/memory_planning.py
@@ -0,0 +1,365 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections
import itertools
import logging
from functools import partial
from typing import Iterable, List, Optional, Tuple

import torch
from executorch.backends.cadence.aot.utils import MemoryConfig

from executorch.exir import ExecutorchProgramManager
from executorch.exir.memory_planning import collect_specs_from_nodes, Verifier
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.tensor import TensorSpec
from tabulate import tabulate
from torch.export.exported_program import ExportGraphSignature
from torch.fx.passes.infra.pass_base import PassResult


# Return the number of memories + 1, so that memories can be indexed 1..N,
# compatible with EXIR's spec.mem_id (index 0 is left unused).
def get_num_memories(memory_config: MemoryConfig) -> int:
return len(memory_config.memory_sizes) + 1
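# Illustrative example (hypothetical sizes, not from the PR): with
# memory_config.memory_sizes = [65536, 131072] this returns 3, so the
# placement loops below iterate over mem_ids 1 and 2.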


# memory_config.memory_sizes is 0-indexed, while EXIR mem_ids start at 1,
# so map exir_id to the corresponding entry.
def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
return memory_config.memory_sizes[exir_id - 1]


def collect_specs_from_graph_module(
graph_module: torch.fx.GraphModule,
alloc_graph_input: bool,
alloc_graph_output: bool,
) -> Iterable[TensorSpec]:
"""
Return the specs for all the nodes in the graph module in
topological order.
"""
# Collect the specs from all the nodes in the graph module, and return it
return collect_specs_from_nodes(
graph_module.graph.nodes,
ignore_graph_input=not alloc_graph_input,
ignore_graph_output=not alloc_graph_output,
)


# Baseline tensor placement algorithm that greedily tries to place each tensor
# in the fastest memory available.
def position_based_greedy_with_hierarchy(
graph_module: torch.fx.GraphModule,
alignment: int,
graph_signature: ExportGraphSignature,
alloc_graph_input: bool,
alloc_graph_output: bool,
*,
memory_config: MemoryConfig,
) -> List[int]:
num_memories = get_num_memories(memory_config)
bufsizes = [0] * num_memories
allocated_buffers: List[List[TensorSpec]] = [[] for _ in range(num_memories)]

def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
for allocated_spec in allocated_buffers[spec.mem_id]:
if Verifier.lifetime_overlap(
spec, allocated_spec
) and Verifier.storage_overlap(spec, allocated_spec):
return allocated_spec
return None

def memory_available(spec: TensorSpec) -> bool:
return spec.mem_offset + spec.allocated_memory <= get_size(
memory_config, spec.mem_id
)

# Iterate over all the specs in sorted order
for spec in sorted(
collect_specs_from_graph_module(
graph_module, alloc_graph_input, alloc_graph_output
),
key=lambda spec: spec.allocated_memory,
reverse=True,
):
for spec.mem_id in range(1, num_memories):
spec.mem_offset = 0
while memory_available(spec) and (overlapped := overlap(spec)):
spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
if memory_available(spec):
allocated_buffers[spec.mem_id].append(spec)
bufsizes[spec.mem_id] = max(
spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
)
break
if (
not allocated_buffers[spec.mem_id]
or allocated_buffers[spec.mem_id][-1] is not spec
):
raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")

logging.debug(
f"position based greedy algorithm with hierarchy returns bufsizes: {bufsizes}"
)
return bufsizes
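
# Illustrative walk-through of the position-based greedy placement above
# (a hypothetical example, not taken from the PR): assume a single 4 KiB
# memory (mem_id 1) and two 1 KiB specs with overlapping lifetimes. The
# first spec is placed at mem_offset 0; the second overlaps it in both
# lifetime and storage, so the while-loop bumps its mem_offset to
# 0 + 1024 = 1024, which still fits, and bufsizes[1] ends up as 2048.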


# Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf
def greedy_by_size_for_offset_calculation_with_hierarchy(
graph_module: torch.fx.GraphModule,
alignment: int,
graph_signature: ExportGraphSignature,
alloc_graph_input: bool,
alloc_graph_output: bool,
*,
memory_config: MemoryConfig,
) -> List[int]:
num_memories = get_num_memories(memory_config)
bufsizes = [0] * num_memories
allocated_buffers: List[List[TensorSpec]] = [[] for _ in range(num_memories)]

# Iterate over all the specs in sorted order
for spec in sorted(
collect_specs_from_graph_module(
graph_module, alloc_graph_input, alloc_graph_output
),
key=lambda spec: spec.allocated_memory,
reverse=True,
):
for spec.mem_id in range(1, num_memories):
prev_offset, smallest_gap = 0, float("inf")
for allocated_spec in allocated_buffers[spec.mem_id]:
if Verifier.lifetime_overlap(spec, allocated_spec):
if (
gap := allocated_spec.mem_offset - prev_offset
) >= spec.allocated_memory and gap < smallest_gap:
smallest_gap = gap
spec.mem_offset = prev_offset
# Note that, unlike the paper, which updates prev_offset for all allocated
# tensors, we only update it for tensors with overlapping lifetimes.
# Updating prev_offset outside the if statement would include tensors
# without overlapping lifetimes, wasting memory and making the gap
# calculation incorrect; it would also degenerate the algorithm to the
# naive one that reuses no memory at all. The paper may have a typo here.
prev_offset = max(
allocated_spec.mem_offset + allocated_spec.allocated_memory,
prev_offset,
)
if spec.mem_offset is None:
if prev_offset + spec.allocated_memory > get_size(
memory_config, spec.mem_id
):
continue
else:
spec.mem_offset = prev_offset
bufsizes[spec.mem_id] = max(
spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
)
allocated_buffers[spec.mem_id].append(spec)
allocated_buffers[spec.mem_id].sort(key=lambda spec: spec.mem_offset)
# allocated_buffers[spec.mem_id], kept sorted by offset, is the data
# structure named ordered_allocated_ids in the paper.
break
if spec not in allocated_buffers[spec.mem_id]:
raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")

logging.debug(
f"greedy by size for offset calculation with hierarchy returns bufsizes: {bufsizes}"
)
return bufsizes
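
# Illustrative walk-through of the gap heuristic above (a hypothetical
# example, not taken from the PR): suppose mem_id 1 already holds two
# lifetime-overlapping tensors at offsets [0, 1024) and [3072, 4096), and a
# new 1 KiB spec overlaps both in lifetime. Scanning in offset order, the
# gap between prev_offset = 1024 and the tensor at 3072 is 2048 >= 1024, so
# the spec is placed at mem_offset 1024, reusing the hole instead of
# growing the buffer past 4096.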


def find_peak_memory_usages_per_memory(
graph_module: torch.fx.GraphModule,
alloc_graph_input: bool,
alloc_graph_output: bool,
) -> List[int]:
"""
Given a GraphModule with a memory plan, find the peak memory usages for each memory
in the memory hierarchy.
"""
# Create a defaultdict to keep track of memory usages: {mem_id: mem_usage}.
# Use a defaultdict because we don't know in advance how many unique mem_ids
# the memory plan uses.
usages = collections.defaultdict(int)

# go through all nodes in the graph, collect memory usage per spec.mem_id
for spec in collect_specs_from_graph_module(
graph_module, alloc_graph_input, alloc_graph_output
):
usages[spec.mem_id] = max(
usages[spec.mem_id], spec.mem_offset + spec.allocated_memory
)

# Convert the usages dict into a list of length max_mem_id, where index i
# holds the peak usage of mem_id i + 1.
# Ex: {1: 20, 3: 30} -> [20, 0, 30]
#                        ^   ^   ^
#                        |   |   |_ mem_id 3
#                        |   |_ mem_id 2
#                        |_ mem_id 1
max_mem_id = max(usages.keys(), default=0)
usages = [usages[i] for i in range(1, max_mem_id + 1)]

return usages


def find_peak_memory_usage(
graph_module: torch.fx.GraphModule,
alloc_graph_input: bool,
alloc_graph_output: bool,
) -> Tuple[int, int]:
"""
Given a GraphModule with a memory plan, find the peak usage over time across all
memories in the memory hierarchy. The resulting peak memory usage should be:
1. >= min(find_peak_memory_usages_per_memory(graph_module))
2. <= sum(find_peak_memory_usages_per_memory(graph_module))
"""
# memory allocations over time (measured in node index)
byte_allocated = [0] * (len(graph_module.graph.nodes) + 1)

# Iterate over all the node specs
for spec in collect_specs_from_graph_module(
graph_module, alloc_graph_input, alloc_graph_output
):
if spec.lifetime[0] is None:
continue

# lifetime is [start, end], both ends inclusive
start, end = spec.lifetime
byte_allocated[start] += spec.allocated_memory
byte_allocated[end + 1] -= spec.allocated_memory

# accumulate the bytes allocated/deallocated to get memory usages
memory_usages = list(itertools.accumulate(byte_allocated))

# find the peak memory usage and the index
peak_memory_usage = max(memory_usages, default=0)
peak_memory_usage_node_idx = (
memory_usages.index(peak_memory_usage) if memory_usages else 0
)

return peak_memory_usage, peak_memory_usage_node_idx
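
# Illustrative example of the sweep above (hypothetical numbers): two specs
# of 100 and 50 bytes with lifetimes [0, 2] and [1, 3] yield
# byte_allocated = [100, 50, 0, -100, -50, ...]; the running sum is
# [100, 150, 150, 50, 0, ...], so the peak is 150 bytes at node index 1.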


# Print two tables with relevant memory planning information
#
# Per Memory Space Usage Table:
# +--------------------------------------+----------------+-----------------------+-----------------------------+
# | Memory Space | Base Address | Memory Size (Bytes) | Peak Memory Usage (Bytes) |
# +======================================+================+=======================+=============================+
# | MEMORY SPACE A | 0x57be0000 | 65213 | 64544 |
# | MEMORY SPACE B | 0x57bf0000 | 65521 | 36864 |
# | MEMORY SPACE ... | ... | ... | ... |
# +--------------------------------------+----------------+-----------------------+-----------------------------+
#
# Total Memory Space Usage Table:
# +-------------------------------------+---------------+---------+
# | Peak memory usage across all spaces | 2380032 bytes | Node 86 |
# +-------------------------------------+---------------+---------+
def print_memory_planning_info(
# pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type.
executorch_prog: ExecutorchProgramManager,
memory_config: MemoryConfig,
alloc_graph_input: bool,
alloc_graph_output: bool,
) -> None:
# Get the peak memory usages per memory space
peak_memory_usages_per_memory = find_peak_memory_usages_per_memory(
executorch_prog.exported_program().graph_module,
alloc_graph_input,
alloc_graph_output,
)

# Create a table of memory spaces and their base addresses, total memory sizes, and peak memory usage
memory_names, base_addrs = memory_config.memory_names, memory_config.base_addrs
memory_usage_table = [
[
f"{(i + 1) if memory_names is None else memory_names[i]}",
None if base_addrs is None else hex(base_addrs[i]),
memory_config.memory_sizes[i],
peak_memory_usages_per_memory[i],
]
for i in range(len(peak_memory_usages_per_memory))
]

# Print the memory usage per memory space as a table
logging.info(
tabulate(
memory_usage_table,
headers=[
"Memory Space",
"Base Address",
"Memory Size (Bytes)",
"Peak Memory Usage (Bytes)",
],
tablefmt="outline",
)
)

# Get the total peak memory usage across all memory spaces
total_peak_memory_usage = find_peak_memory_usage(
executorch_prog.exported_program().graph_module,
alloc_graph_input,
alloc_graph_output,
)

# Create a table with total peak memory usage and node at which this occurs
total_memory_usage_table = [
[
"Peak memory usage across all spaces",
f"{total_peak_memory_usage[0]} bytes",
f"Node {total_peak_memory_usage[1]}",
]
]

# Print the total memory usage as a table
logging.info(
tabulate(
total_memory_usage_table,
tablefmt="outline",
)
)


class CadenceMemoryPlanning:
def __init__(
self,
memory_config: MemoryConfig,
mem_algo: int,
alloc_graph_input: bool = True,
alloc_graph_output: bool = True,
) -> None:
self._init_mem_algos()

self.memory_config = memory_config
self.mem_algo = mem_algo
self.alloc_graph_input = alloc_graph_input
self.alloc_graph_output = alloc_graph_output

def _init_mem_algos(self) -> None:
self.available_mem_algos = [
position_based_greedy_with_hierarchy,
greedy_by_size_for_offset_calculation_with_hierarchy,
]

def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
algo = partial(
self.available_mem_algos[self.mem_algo],
memory_config=self.memory_config,
)
# Create the memory planning pass. We allocate memory for input
# (output) tensors if alloc_graph_input (alloc_graph_output) is
# True.
mem_planning = MemoryPlanningPass(
algo,
allow_lifetime_and_storage_overlap=False,
alloc_graph_input=self.alloc_graph_input,
alloc_graph_output=self.alloc_graph_output,
)
mem_planning(graph_module)

return PassResult(graph_module, True)
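
# Illustrative usage sketch (not part of this file; names and values are
# assumptions for illustration). Assuming MemoryConfig can be constructed
# with a `memory_sizes` list (see backends/cadence/aot/utils.py for its
# actual fields) and that `gm` is an exported torch.fx.GraphModule:
#
#   memory_config = MemoryConfig(memory_sizes=[65536, 131072])
#   planner = CadenceMemoryPlanning(
#       memory_config,
#       mem_algo=1,  # greedy_by_size_for_offset_calculation_with_hierarchy
#   )
#   result = planner(gm)  # PassResult with the memory-planned graph_module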