
Commit f145ced

Port memory planning to Cadence
Differential Revision: D64406681
Pull Request resolved: #6716
1 parent 88df185 commit f145ced

File tree

2 files changed: +392 -1 lines changed

Lines changed: 365 additions & 0 deletions
@@ -0,0 +1,365 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections
import itertools
import logging
from functools import partial
from typing import Iterable, List, Optional, Tuple

import torch
from executorch.backends.cadence.aot.utils import MemoryConfig

from executorch.exir import ExecutorchProgramManager
from executorch.exir.memory_planning import collect_specs_from_nodes, Verifier
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.tensor import TensorSpec
from tabulate import tabulate
from torch.export.exported_program import ExportGraphSignature
from torch.fx.passes.infra.pass_base import PassResult


# get num memories indexed from 1..N, compatible with EXIR's spec.mem_id
def get_num_memories(memory_config: MemoryConfig) -> int:
    return len(memory_config.memory_sizes) + 1


# memory_space module provides num_memories indexed 0..num_memories-1.
def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
    return memory_config.memory_sizes[exir_id - 1]


def collect_specs_from_graph_module(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> Iterable[TensorSpec]:
    """
    Return the specs for all the nodes in the graph module in
    topological order.
    """
    # Collect the specs from all the nodes in the graph module, and return it
    return collect_specs_from_nodes(
        graph_module.graph.nodes,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    )


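# Illustrative note (added for exposition, not part of the original diff):
# for a hypothetical MemoryConfig(memory_sizes=[0x10000, 0x40000]),
# get_num_memories() returns 3, i.e. mem_ids 1 and 2 plus the reserved id 0,
# and get_size(config, 1) reads memory_sizes[0] == 0x10000, since EXIR
# mem_ids are 1-based while the memory_sizes list is 0-based.

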
# Baseline tensor placement algorithm that greedily tries to place each tensor
# in the fastest memory available.
def position_based_greedy_with_hierarchy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: ExportGraphSignature,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
    *,
    memory_config: MemoryConfig,
) -> List[int]:
    num_memories = get_num_memories(memory_config)
    bufsizes = [0] * num_memories
    allocated_buffers: List[List[TensorSpec]] = [[] for _ in range(num_memories)]

    def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
        for allocated_spec in allocated_buffers[spec.mem_id]:
            if Verifier.lifetime_overlap(
                spec, allocated_spec
            ) and Verifier.storage_overlap(spec, allocated_spec):
                return allocated_spec
        return None

    def memory_available(spec: TensorSpec) -> bool:
        return spec.mem_offset + spec.allocated_memory <= get_size(
            memory_config, spec.mem_id
        )

    # Iterate over all the specs in sorted order
    for spec in sorted(
        collect_specs_from_graph_module(
            graph_module, alloc_graph_input, alloc_graph_output
        ),
        key=lambda spec: spec.allocated_memory,
        reverse=True,
    ):
        for spec.mem_id in range(1, num_memories):
            spec.mem_offset = 0
            while memory_available(spec) and (overlapped := overlap(spec)):
                spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
            if memory_available(spec):
                allocated_buffers[spec.mem_id].append(spec)
                bufsizes[spec.mem_id] = max(
                    spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
                )
                break
        if (
            not allocated_buffers[spec.mem_id]
            or allocated_buffers[spec.mem_id][-1] is not spec
        ):
            raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")

    logging.debug(
        f"position based greedy algorithm with hierarchy returns bufsizes: {bufsizes}"
    )
    return bufsizes


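# Illustrative walk-through (added for exposition, not part of the original
# diff): with memory_sizes=[1000, 4000], a 600-byte spec whose lifetime and
# storage overlap an already-placed 500-byte spec at offset 0 in memory 1 is
# first bumped to offset 500, which no longer fits (500 + 600 > 1000), so the
# algorithm falls through to memory 2 and places it there at offset 0.

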
# Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf
def greedy_by_size_for_offset_calculation_with_hierarchy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: ExportGraphSignature,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
    *,
    memory_config: MemoryConfig,
) -> List[int]:
    num_memories = get_num_memories(memory_config)
    bufsizes = [0] * num_memories
    allocated_buffers = [[] for _ in range(num_memories)]

    # Iterate over all the specs in sorted order
    for spec in sorted(
        collect_specs_from_graph_module(
            graph_module, alloc_graph_input, alloc_graph_output
        ),
        key=lambda spec: spec.allocated_memory,
        reverse=True,
    ):
        for spec.mem_id in range(1, num_memories):
            prev_offset, smallest_gap = 0, float("inf")
            for allocated_spec in allocated_buffers[spec.mem_id]:
                if Verifier.lifetime_overlap(spec, allocated_spec):
                    if (
                        gap := allocated_spec.mem_offset - prev_offset
                    ) >= spec.allocated_memory and gap < smallest_gap:
                        smallest_gap = gap
                        spec.mem_offset = prev_offset
                    # Note that, unlike the paper, which updates prev_offset for all
                    # allocated tensors, we only update it for tensors with overlapping
                    # lifetimes. Updating prev_offset outside this if statement would
                    # include tensors without overlapping lifetimes, wasting memory and
                    # making the gap calculation incorrect. Moving it out would make the
                    # algorithm degenerate to the naive one, reusing zero tensors. The
                    # paper may have a typo here.
                    prev_offset = max(
                        allocated_spec.mem_offset + allocated_spec.allocated_memory,
                        prev_offset,
                    )
            if spec.mem_offset is None:
                if prev_offset + spec.allocated_memory > get_size(
                    memory_config, spec.mem_id
                ):
                    continue
                else:
                    spec.mem_offset = prev_offset
            bufsizes[spec.mem_id] = max(
                spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
            )
            allocated_buffers[spec.mem_id].append(spec)
            allocated_buffers[spec.mem_id].sort(key=lambda spec: spec.mem_offset)
            # allocated_buffers maintains the tensor order by offset; it is the
            # data structure named ordered_allocated_ids in the paper.
            break
        if spec not in allocated_buffers[spec.mem_id]:
            raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")

    logging.debug(
        f"greedy by size for offset calculation with hierarchy returns bufsizes: {bufsizes}"
    )
    return bufsizes


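# Illustrative walk-through of the gap search above (added for exposition, not
# part of the original diff): suppose memory 1 already holds two
# lifetime-overlapping specs occupying offsets [0, 100) and [160, 200), and the
# next spec needs 40 bytes. Scanning the allocated specs in offset order gives
# prev_offset = 100 after the first spec; the gap to the next offset,
# 160 - 100 = 60, is >= 40, so the new spec is placed at offset 100, reusing
# the hole between live tensors instead of growing the buffer past 200.

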
def find_peak_memory_usages_per_memory(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> List[int]:
    """
    Given a GraphModule with a memory plan, find the peak memory usage of each memory
    in the memory hierarchy.
    """
    # Create a defaultdict to keep track of memory usages: {mem_id: mem_usage}
    # Use a defaultdict here because we don't know how many unique memory_ids
    # in the memory hierarchy were used in memory planning.
    usages = collections.defaultdict(int)

    # go through all nodes in the graph, collect memory usage per spec.mem_id
    for spec in collect_specs_from_graph_module(
        graph_module, alloc_graph_input, alloc_graph_output
    ):
        usages[spec.mem_id] = max(
            usages[spec.mem_id], spec.mem_offset + spec.allocated_memory
        )

    # Convert the usages dictionary into a list whose length is the max memory id,
    # with index i holding the usage of mem_id i + 1.
    # Ex: {1: 20, 3: 30} -> [20, 0, 30]
    #                         |   |   |_ usage of mem_id 3
    #                         |   |_ usage of mem_id 2 (unused)
    #                         |_ usage of mem_id 1
    max_mem_id = max(usages.keys(), default=0)
    usages = [usages[i] for i in range(1, max_mem_id + 1)]

    return usages


def find_peak_memory_usage(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> Tuple[int, int]:
    """
    Given a GraphModule with a memory plan, find the peak usage over time across all
    memories in the memory hierarchy. The resulting peak memory usage should be:
    1. >= min(find_peak_memory_usages_per_memory(graph_module))
    2. <= sum(find_peak_memory_usages_per_memory(graph_module))
    """
    # memory allocations over time (measured in node index)
    byte_allocated = [0] * (len(graph_module.graph.nodes) + 1)

    # Iterate over all the node specs
    for spec in collect_specs_from_graph_module(
        graph_module, alloc_graph_input, alloc_graph_output
    ):
        if spec.lifetime[0] is None:
            continue

        # lifetime is [start, end], both ends inclusive
        start, end = spec.lifetime
        byte_allocated[start] += spec.allocated_memory
        byte_allocated[end + 1] -= spec.allocated_memory

    # accumulate the bytes allocated/deallocated to get memory usages
    memory_usages = list(itertools.accumulate(byte_allocated))

    # find the peak memory usage and the node index at which it occurs
    peak_memory_usage = max(memory_usages, default=0)
    peak_memory_usage_node_idx = (
        memory_usages.index(peak_memory_usage) if memory_usages else 0
    )

    return peak_memory_usage, peak_memory_usage_node_idx


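# Illustrative walk-through (added for exposition, not part of the original
# diff): with a 4-node graph, a 100-byte spec live over nodes [0, 2] and a
# 50-byte spec live over nodes [2, 3], byte_allocated becomes
# [100, 0, 50, -100, -50]; the running sum is [100, 100, 150, 50, 0], so the
# peak is 150 bytes at node index 2, where the two lifetimes overlap.

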
# Print two tables with relevant memory planning information
#
# Per Memory Space Usage Table:
# +--------------------------------------+----------------+-----------------------+-----------------------------+
# | Memory Space                         |   Base Address |   Memory Size (Bytes) |   Peak Memory Usage (Bytes) |
# +======================================+================+=======================+=============================+
# | MEMORY SPACE A                       |     0x57be0000 |                 65213 |                       64544 |
# | MEMORY SPACE B                       |     0x57bf0000 |                 65521 |                       36864 |
# | MEMORY SPACE ...                     |            ... |                   ... |                         ... |
# +--------------------------------------+----------------+-----------------------+-----------------------------+
#
# Total Memory Space Usage Table:
# +-------------------------------------+---------------+---------+
# | Peak memory usage across all spaces | 2380032 bytes | Node 86 |
# +-------------------------------------+---------------+---------+
def print_memory_planning_info(
    # pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type.
    executorch_prog: ExecutorchProgramManager,
    memory_config: MemoryConfig,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> None:
    # Get the peak memory usages per memory space
    peak_memory_usages_per_memory = find_peak_memory_usages_per_memory(
        executorch_prog.exported_program().graph_module,
        alloc_graph_input,
        alloc_graph_output,
    )

    # Create a table of memory spaces and their base addresses, total memory
    # sizes, and peak memory usage
    memory_names, base_addrs = memory_config.memory_names, memory_config.base_addrs
    memory_usage_table = [
        [
            f"{(i + 1) if memory_names is None else memory_names[i]}",
            None if base_addrs is None else hex(base_addrs[i]),
            memory_config.memory_sizes[i],
            peak_memory_usages_per_memory[i],
        ]
        for i in range(len(peak_memory_usages_per_memory))
    ]

    # Print the memory usage per memory space as a table
    logging.info(
        tabulate(
            memory_usage_table,
            headers=[
                "Memory Space",
                "Base Address",
                "Memory Size (Bytes)",
                "Peak Memory Usage (Bytes)",
            ],
            tablefmt="outline",
        )
    )

    # Get the total peak memory usage across all memory spaces
    total_peak_memory_usage = find_peak_memory_usage(
        executorch_prog.exported_program().graph_module,
        alloc_graph_input,
        alloc_graph_output,
    )

    # Create a table with the total peak memory usage and the node at which it occurs
    total_memory_usage_table = [
        [
            "Peak memory usage across all spaces",
            f"{total_peak_memory_usage[0]} bytes",
            f"Node {total_peak_memory_usage[1]}",
        ]
    ]

    # Print the total memory usage as a table
    logging.info(
        tabulate(
            total_memory_usage_table,
            tablefmt="outline",
        )
    )


class CadenceMemoryPlanning:
    def __init__(
        self,
        memory_config: MemoryConfig,
        mem_algo: int,
        alloc_graph_input: bool = True,
        alloc_graph_output: bool = True,
    ) -> None:
        self._init_mem_algos()

        self.memory_config = memory_config
        self.mem_algo = mem_algo
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output

    def _init_mem_algos(self) -> None:
        self.available_mem_algos = [
            position_based_greedy_with_hierarchy,
            greedy_by_size_for_offset_calculation_with_hierarchy,
        ]

    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
        algo = partial(
            self.available_mem_algos[self.mem_algo],
            memory_config=self.memory_config,
        )
        # Create the memory planning pass. We allocate memory for input
        # (output) tensors if alloc_graph_input (alloc_graph_output) is
        # True.
        mem_planning = MemoryPlanningPass(
            algo,
            allow_lifetime_and_storage_overlap=False,
            alloc_graph_input=self.alloc_graph_input,
            alloc_graph_output=self.alloc_graph_output,
        )
        mem_planning(graph_module)

        return PassResult(graph_module, True)
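
Usage sketch (not part of the commit; a minimal, hedged example of how the pass above might be wired up, assuming MemoryConfig can be constructed with just memory_sizes, with memory_names and base_addrs left as None, as tolerated by print_memory_planning_info, and that graph_module is the torch.fx.GraphModule of an already exported and lowered program):

    from executorch.backends.cadence.aot.utils import MemoryConfig

    # Hypothetical two-level memory hierarchy: 64 KiB of fast local memory
    # (mem_id 1) and 256 KiB of slower memory (mem_id 2).
    memory_config = MemoryConfig(memory_sizes=[64 * 1024, 256 * 1024])

    # mem_algo=1 selects greedy_by_size_for_offset_calculation_with_hierarchy;
    # mem_algo=0 would select position_based_greedy_with_hierarchy.
    planner = CadenceMemoryPlanning(
        memory_config,
        mem_algo=1,
        alloc_graph_input=True,
        alloc_graph_output=True,
    )

    # The pass assigns spec.mem_id / spec.mem_offset on the graph's tensor
    # specs and returns a PassResult wrapping the same graph module.
    result = planner(graph_module)
    peaks = find_peak_memory_usages_per_memory(
        result.graph_module, alloc_graph_input=True, alloc_graph_output=True
    )
    print(peaks)  # one peak value per memory in the hierarchy, e.g. [64544, 36864]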
