# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections
import itertools
import logging
from functools import partial
from typing import Iterable, List, Optional, Tuple

import torch
from executorch.backends.cadence.aot.utils import MemoryConfig

from executorch.exir import ExecutorchProgramManager
from executorch.exir.memory_planning import collect_specs_from_nodes, Verifier
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.tensor import TensorSpec
from tabulate import tabulate
from torch.export.exported_program import ExportGraphSignature
from torch.fx.passes.infra.pass_base import PassResult


# Get the number of memories, indexed from 1..N, compatible with EXIR's spec.mem_id
def get_num_memories(memory_config: MemoryConfig) -> int:
    return len(memory_config.memory_sizes) + 1


# memory_config.memory_sizes is 0-based, while EXIR's spec.mem_id is 1-based;
# return the size of the memory with the given EXIR mem_id.
def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
    return memory_config.memory_sizes[exir_id - 1]
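
# As an illustration of the indexing convention above (the concrete sizes are
# made up): a MemoryConfig with memory_sizes=[65536, 262144] describes two
# memories, get_num_memories() returns 3, valid spec.mem_id values are 1 and 2,
# and get_size(config, 1) == 65536 while get_size(config, 2) == 262144.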


def collect_specs_from_graph_module(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> Iterable[TensorSpec]:
    """
    Return the specs for all the nodes in the graph module in
    topological order.
    """
    # Collect the specs from all the nodes in the graph module and return them
    return collect_specs_from_nodes(
        graph_module.graph.nodes,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    )


# Baseline tensor placement algorithm that greedily tries to place each tensor
# in the fastest memory available.
def position_based_greedy_with_hierarchy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: ExportGraphSignature,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
    *,
    memory_config: MemoryConfig,
) -> List[int]:
    num_memories = get_num_memories(memory_config)
    bufsizes = [0] * num_memories
    allocated_buffers: List[List[TensorSpec]] = [[] for _ in range(num_memories)]

    def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
        for allocated_spec in allocated_buffers[spec.mem_id]:
            if Verifier.lifetime_overlap(
                spec, allocated_spec
            ) and Verifier.storage_overlap(spec, allocated_spec):
                return allocated_spec
        return None

    def memory_available(spec: TensorSpec) -> bool:
        return spec.mem_offset + spec.allocated_memory <= get_size(
            memory_config, spec.mem_id
        )

    # Iterate over all the specs, largest allocation first
    for spec in sorted(
        collect_specs_from_graph_module(
            graph_module, alloc_graph_input, alloc_graph_output
        ),
        key=lambda spec: spec.allocated_memory,
        reverse=True,
    ):
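        # Try memories in order of mem_id (the fastest memory first). Within a
        # memory, bump the candidate offset past every already-placed tensor whose
        # lifetime and storage overlap this spec, and accept the first memory in
        # which the spec still fits.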
        for spec.mem_id in range(1, num_memories):
            spec.mem_offset = 0
            while memory_available(spec) and (overlapped := overlap(spec)):
                spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
            if memory_available(spec):
                allocated_buffers[spec.mem_id].append(spec)
                bufsizes[spec.mem_id] = max(
                    spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
                )
                break
        if (
            not allocated_buffers[spec.mem_id]
            or allocated_buffers[spec.mem_id][-1] is not spec
        ):
            raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")

    logging.debug(
        f"position based greedy algorithm with hierarchy returns bufsizes: {bufsizes}"
    )
    return bufsizes


# Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf
def greedy_by_size_for_offset_calculation_with_hierarchy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: ExportGraphSignature,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
    *,
    memory_config: MemoryConfig,
) -> List[int]:
    num_memories = get_num_memories(memory_config)
    bufsizes = [0] * num_memories
    allocated_buffers = [[] for _ in range(num_memories)]

    # Iterate over all the specs, largest allocation first
    for spec in sorted(
        collect_specs_from_graph_module(
            graph_module, alloc_graph_input, alloc_graph_output
        ),
        key=lambda spec: spec.allocated_memory,
        reverse=True,
    ):
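        # Within each memory, walk the already-placed tensors in offset order and
        # look for the smallest gap between lifetime-overlapping tensors that can
        # hold this spec (best fit); if no such gap exists, try placing the spec
        # right after the last lifetime-overlapping tensor.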
        for spec.mem_id in range(1, num_memories):
            prev_offset, smallest_gap = 0, float("inf")
            for allocated_spec in allocated_buffers[spec.mem_id]:
                if Verifier.lifetime_overlap(spec, allocated_spec):
                    if (
                        gap := allocated_spec.mem_offset - prev_offset
                    ) >= spec.allocated_memory and gap < smallest_gap:
                        smallest_gap = gap
                        spec.mem_offset = prev_offset
                    # Note that, unlike the paper, which updates prev_offset for all
                    # allocated tensors, we only update prev_offset for tensors whose
                    # lifetimes overlap. Updating prev_offset outside this if statement
                    # would also include tensors without overlapping lifetimes, wasting
                    # memory and making the gap calculation incorrect; the algorithm
                    # would degenerate to the naive one that reuses no tensor storage.
                    # The paper may have a typo here.
                    prev_offset = max(
                        allocated_spec.mem_offset + allocated_spec.allocated_memory,
                        prev_offset,
                    )
            if spec.mem_offset is None:
                if prev_offset + spec.allocated_memory > get_size(
                    memory_config, spec.mem_id
                ):
                    continue
                else:
                    spec.mem_offset = prev_offset
            bufsizes[spec.mem_id] = max(
                spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
            )
            allocated_buffers[spec.mem_id].append(spec)
            # Keep allocated_buffers sorted by offset; this is the data structure
            # named ordered_allocated_ids in the paper.
            allocated_buffers[spec.mem_id].sort(key=lambda spec: spec.mem_offset)
            break
        if spec not in allocated_buffers[spec.mem_id]:
            raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")

    logging.debug(
        f"greedy by size for offset calculation with hierarchy returns bufsizes: {bufsizes}"
    )
    return bufsizes


def find_peak_memory_usages_per_memory(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> List[int]:
    """
    Given a GraphModule with a memory plan, find the peak memory usage for each
    memory in the memory hierarchy.
    """
    # Create a defaultdict to keep track of memory usages: {mem_id: mem_usage}
    # Use a defaultdict because we don't know up front how many unique mem_ids
    # the memory plan uses.
    usages = collections.defaultdict(int)

    # Go through all the nodes in the graph and collect the memory usage per spec.mem_id
    for spec in collect_specs_from_graph_module(
        graph_module, alloc_graph_input, alloc_graph_output
    ):
        usages[spec.mem_id] = max(
            usages[spec.mem_id], spec.mem_offset + spec.allocated_memory
        )

    # Convert the usages dictionary into a list whose index i holds the peak
    # usage of mem_id i + 1, with 0 for mem_ids that were never used.
    # Ex: {1: 20, 3: 30} -> [20, 0, 30]
    max_mem_id = max(usages.keys(), default=0)
    usages = [usages[i] for i in range(1, max_mem_id + 1)]

    return usages


def find_peak_memory_usage(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> Tuple[int, int]:
    """
    Given a GraphModule with a memory plan, find the peak usage over time across all
    memories in the memory hierarchy. The resulting peak memory usage should be:
    1. >= min(find_peak_memory_usages_per_memory(graph_module))
    2. <= sum(find_peak_memory_usages_per_memory(graph_module))
    """
    # memory allocations over time (measured in node index)
    byte_allocated = [0] * (len(graph_module.graph.nodes) + 1)

    # Iterate over all the node specs
    for spec in collect_specs_from_graph_module(
        graph_module, alloc_graph_input, alloc_graph_output
    ):
        if spec.lifetime[0] is None:
            continue

        # lifetime is [start, end], both ends inclusive
        start, end = spec.lifetime
        byte_allocated[start] += spec.allocated_memory
        byte_allocated[end + 1] -= spec.allocated_memory

    # accumulate the bytes allocated/deallocated to get memory usages
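    # For example (hypothetical numbers): byte_allocated = [4, 0, -4, 8, -8]
    # accumulates to [4, 4, 0, 8, 0], i.e. 4 bytes are live at nodes 0-1, none
    # at node 2, and the peak of 8 bytes occurs at node 3.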
    memory_usages = list(itertools.accumulate(byte_allocated))

    # find the peak memory usage and the index
    peak_memory_usage = max(memory_usages, default=0)
    peak_memory_usage_node_idx = (
        memory_usages.index(peak_memory_usage) if memory_usages else 0
    )

    return peak_memory_usage, peak_memory_usage_node_idx


# Print two tables with relevant memory planning information
#
# Per Memory Space Usage Table:
# +------------------+----------------+-----------------------+-----------------------------+
# | Memory Space     | Base Address   | Memory Size (Bytes)   | Peak Memory Usage (Bytes)   |
# +==================+================+=======================+=============================+
# | MEMORY SPACE A   | 0x57be0000     | 65213                 | 64544                       |
# | MEMORY SPACE B   | 0x57bf0000     | 65521                 | 36864                       |
# | MEMORY SPACE ... | ...            | ...                   | ...                         |
# +------------------+----------------+-----------------------+-----------------------------+
#
# Total Memory Space Usage Table:
# +-------------------------------------+---------------+---------+
# | Peak memory usage across all spaces | 2380032 bytes | Node 86 |
# +-------------------------------------+---------------+---------+
def print_memory_planning_info(
    # pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type.
    executorch_prog: ExecutorchProgramManager,
    memory_config: MemoryConfig,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> None:
    # Get the peak memory usages per memory space
    peak_memory_usages_per_memory = find_peak_memory_usages_per_memory(
        executorch_prog.exported_program().graph_module,
        alloc_graph_input,
        alloc_graph_output,
    )

    # Create a table of memory spaces and their base addresses, total memory sizes, and peak memory usage
    memory_names, base_addrs = memory_config.memory_names, memory_config.base_addrs
    memory_usage_table = [
        [
            f"{(i + 1) if memory_names is None else memory_names[i]}",
            None if base_addrs is None else hex(base_addrs[i]),
            memory_config.memory_sizes[i],
            peak_memory_usages_per_memory[i],
        ]
        for i in range(len(peak_memory_usages_per_memory))
    ]

    # Print the memory usage per memory space as a table
    logging.info(
        tabulate(
            memory_usage_table,
            headers=[
                "Memory Space",
                "Base Address",
                "Memory Size (Bytes)",
                "Peak Memory Usage (Bytes)",
            ],
            tablefmt="outline",
        )
    )

    # Get the total peak memory usage across all memory spaces
    total_peak_memory_usage = find_peak_memory_usage(
        executorch_prog.exported_program().graph_module,
        alloc_graph_input,
        alloc_graph_output,
    )

    # Create a table with the total peak memory usage and the node at which it occurs
    total_memory_usage_table = [
        [
            "Peak memory usage across all spaces",
            f"{total_peak_memory_usage[0]} bytes",
            f"Node {total_peak_memory_usage[1]}",
        ]
    ]

    # Print the total memory usage as a table
    logging.info(
        tabulate(
            total_memory_usage_table,
            tablefmt="outline",
        )
    )


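# A hedged usage sketch (the surrounding export pipeline is an assumption, not
# shown in this file): construct the pass with a MemoryConfig and the index of
# the placement algorithm to use, then run it over a lowered graph module, e.g.
#     planning = CadenceMemoryPlanning(memory_config, mem_algo=1)
#     graph_module = planning(graph_module).graph_module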
class CadenceMemoryPlanning:
    def __init__(
        self,
        memory_config: MemoryConfig,
        mem_algo: int,
        alloc_graph_input: bool = True,
        alloc_graph_output: bool = True,
    ) -> None:
        self._init_mem_algos()

        self.memory_config = memory_config
        self.mem_algo = mem_algo
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output

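    # The mem_algo index passed to __init__ selects from this list:
    # 0 -> position_based_greedy_with_hierarchy,
    # 1 -> greedy_by_size_for_offset_calculation_with_hierarchy.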
    def _init_mem_algos(self) -> None:
        self.available_mem_algos = [
            position_based_greedy_with_hierarchy,
            greedy_by_size_for_offset_calculation_with_hierarchy,
        ]

    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
        algo = partial(
            self.available_mem_algos[self.mem_algo],
            memory_config=self.memory_config,
        )
        # Create the memory planning pass. We allocate memory for input
        # (output) tensors if alloc_graph_input (alloc_graph_output) is
        # True.
        mem_planning = MemoryPlanningPass(
            algo,
            allow_lifetime_and_storage_overlap=False,
            alloc_graph_input=self.alloc_graph_input,
            alloc_graph_output=self.alloc_graph_output,
        )
        mem_planning(graph_module)

        return PassResult(graph_module, True)