
Commit 83d3b48

Eashan Garg authored and facebook-github-bot committed
Port memory planning to Cadence (#6716)
Summary: Porting memory planning over to Cadence OSS

Reviewed By: hsharma35

Differential Revision: D64406681
1 parent 7895982 commit 83d3b48

File tree

2 files changed: +393 -1 lines changed

Lines changed: 366 additions & 0 deletions
@@ -0,0 +1,366 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections
import itertools
import logging
from functools import partial
from typing import Iterable, List, Optional, Tuple

import torch

from executorch.exir import ExecutorchProgramManager
from executorch.exir.memory_planning import collect_specs_from_nodes, Verifier
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.tensor import TensorSpec
from executorch.backends.cadence.aot.utils import MemoryConfig
from tabulate import tabulate
from torch.export.exported_program import ExportGraphSignature
from torch.fx.passes.infra.pass_base import PassResult


# get num memories indexed from 1..N, compatible with EXIR's spec.mem_id
def get_num_memories(memory_config: MemoryConfig) -> int:
    return len(memory_config.memory_sizes) + 1


# memory_space module provides num_memories indexed 0..num_memories-1.
def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
    return memory_config.memory_sizes[exir_id - 1]


def collect_specs_from_graph_module(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> Iterable[TensorSpec]:
    """
    Return the specs for all the nodes in the graph module in
    topological order.
    """
    # Collect the specs from all the nodes in the graph module, and return it
    return collect_specs_from_nodes(
        graph_module.graph.nodes,
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    )

# baseline tensor placement algorithm, that greedily tries to place the tensor in
# the fastest memory available
def position_based_greedy_with_hierarchy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: ExportGraphSignature,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
    *,
    memory_config: MemoryConfig,
) -> List[int]:
    num_memories = get_num_memories(memory_config)
    bufsizes = [0] * num_memories
    allocated_buffers: List[List[TensorSpec]] = [[] for _ in range(num_memories)]

    def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
        for allocated_spec in allocated_buffers[spec.mem_id]:
            if Verifier.lifetime_overlap(
                spec, allocated_spec
            ) and Verifier.storage_overlap(spec, allocated_spec):
                return allocated_spec
        return None

    def memory_available(spec: TensorSpec) -> bool:
        return spec.mem_offset + spec.allocated_memory <= get_size(
            memory_config, spec.mem_id
        )

    # Iterate over all the specs in sorted order
    for spec in sorted(
        collect_specs_from_graph_module(
            graph_module, alloc_graph_input, alloc_graph_output
        ),
        key=lambda spec: spec.allocated_memory,
        reverse=True,
    ):
        for spec.mem_id in range(1, num_memories):
            spec.mem_offset = 0
            while memory_available(spec) and (overlapped := overlap(spec)):
                spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
            if memory_available(spec):
                allocated_buffers[spec.mem_id].append(spec)
                bufsizes[spec.mem_id] = max(
                    spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
                )
                break
        if (
            not allocated_buffers[spec.mem_id]
            or allocated_buffers[spec.mem_id][-1] is not spec
        ):
            raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")

    logging.debug(
        f"position based greedy algorithm with hierarchy returns bufsizes: {bufsizes}"
    )
    return bufsizes

# Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf
def greedy_by_size_for_offset_calculation_with_hierarchy(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: ExportGraphSignature,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
    *,
    memory_config: MemoryConfig,
) -> List[int]:
    num_memories = get_num_memories(memory_config)
    bufsizes = [0] * num_memories
    allocated_buffers = [[] for _ in range(num_memories)]

    # Iterate over all the specs in sorted order
    for spec in sorted(
        collect_specs_from_graph_module(
            graph_module, alloc_graph_input, alloc_graph_output
        ),
        key=lambda spec: spec.allocated_memory,
        reverse=True,
    ):
        for spec.mem_id in range(1, num_memories):
            prev_offset, smallest_gap = 0, float("inf")
            for allocated_spec in allocated_buffers[spec.mem_id]:
                if Verifier.lifetime_overlap(spec, allocated_spec):
                    if (
                        gap := allocated_spec.mem_offset - prev_offset
                    ) >= spec.allocated_memory and gap < smallest_gap:
                        smallest_gap = gap
                        spec.mem_offset = prev_offset
                    # Note that different from the paper, which updates prev_offset for all
                    # allocated tensors, we only update tensors with overlapping lifetime.
                    # Updating prev_offset outside the if statement will include tensors without
                    # overlapping lifetime, causing unnecessary waste of memory and make the
                    # calculation of gap incorrect. Moving it out will make the algorithm degenerate
                    # to the naive one, reusing 0 tensor. The paper may have a typo here.
                    prev_offset = max(
                        allocated_spec.mem_offset + allocated_spec.allocated_memory,
                        prev_offset,
                    )
            if spec.mem_offset is None:
                if prev_offset + spec.allocated_memory > get_size(
                    memory_config, spec.mem_id
                ):
                    continue
                else:
                    spec.mem_offset = prev_offset
            bufsizes[spec.mem_id] = max(
                spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
            )
            allocated_buffers[spec.mem_id].append(spec)
            allocated_buffers[spec.mem_id].sort(key=lambda spec: spec.mem_offset)
            # A data structure used for maintaining the tensor order
            # by offset, named ordered_allocated_ids in the paper
            break
        if spec not in allocated_buffers[spec.mem_id]:
            raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")

    logging.debug(
        f"greedy by size for offset calculation with hierarchy returns bufsizes: {bufsizes}"
    )
    return bufsizes

def find_peak_memory_usages_per_memory(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> List[int]:
    """
    Given a GraphModule with a memory plan, find the peak memory usages for each memory
    in the memory hierarchy.
    """
    # Create a defaultdict to keep track of memory usages: {mem_id: mem_usage}
    # Use a defaultdict here because we don't know how many unique memory_ids in
    # the memory hierarchy are used in memory planning.
    usages = collections.defaultdict(int)

    # go through all nodes in the graph, collect memory usage per spec.mem_id
    for spec in collect_specs_from_graph_module(
        graph_module, alloc_graph_input, alloc_graph_output
    ):
        usages[spec.mem_id] = max(
            usages[spec.mem_id], spec.mem_offset + spec.allocated_memory
        )

    # Convert usages dictionary into list of len of max memory id
    # Ex: {1: 20, 3: 30} -> [0, 20, 0, 30].
    #                        ^   ^  ^   ^
    #                        |   |  |   |_ mem_id 3
    #                        |   |  |_ mem_id 2
    #                        |   |_ mem_id 1
    #                        |_ mem_id 0
    max_mem_id = max(usages.keys(), default=0)
    usages = [usages[i] for i in range(1, max_mem_id + 1)]

    return usages


def find_peak_memory_usage(
    graph_module: torch.fx.GraphModule,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> Tuple[int, int]:
    """
    Given a GraphModule with a memory plan, find the peak usage over time across all
    memories in the memory hierarchy. The resulting peak memory usage should be:
    1. >= min(find_peak_memory_usages_per_memory(graph_module))
    2. <= sum(find_peak_memory_usages_per_memory(graph_module))
    """
    # memory allocations over time (measured in node index)
    byte_allocated = [0] * (len(graph_module.graph.nodes) + 1)

    # Iterate over all the node specs
    for spec in collect_specs_from_graph_module(
        graph_module, alloc_graph_input, alloc_graph_output
    ):
        if spec.lifetime[0] is None:
            continue

        # lifetime is [start, end], both ends inclusive
        start, end = spec.lifetime
        byte_allocated[start] += spec.allocated_memory
        byte_allocated[end + 1] -= spec.allocated_memory

    # accumulate the bytes allocated/deallocated to get memory usages
    memory_usages = list(itertools.accumulate(byte_allocated))

    # find the peak memory usage and the index
    peak_memory_usage = max(memory_usages, default=0)
    peak_memory_usage_node_idx = (
        memory_usages.index(peak_memory_usage) if memory_usages else 0
    )

    return peak_memory_usage, peak_memory_usage_node_idx

# Print two tables with relevant memory planning information
#
# Per Memory Space Usage Table:
# +------------------+--------------+---------------------+---------------------------+
# | Memory Space     | Base Address | Memory Size (Bytes) | Peak Memory Usage (Bytes) |
# +==================+==============+=====================+===========================+
# | MEMORY SPACE A   | 0x57be0000   | 65213               | 64544                     |
# | MEMORY SPACE B   | 0x57bf0000   | 65521               | 36864                     |
# | MEMORY SPACE ... | ...          | ...                 | ...                       |
# +------------------+--------------+---------------------+---------------------------+
#
# Total Memory Space Usage Table:
# +-------------------------------------+---------------+---------+
# | Peak memory usage across all spaces | 2380032 bytes | Node 86 |
# +-------------------------------------+---------------+---------+
def print_memory_planning_info(
    # pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type.
    executorch_prog: ExecutorchProgramManager,
    memory_config: MemoryConfig,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
) -> None:
    # Get the peak memory usages per memory space
    peak_memory_usages_per_memory = find_peak_memory_usages_per_memory(
        executorch_prog.exported_program().graph_module,
        alloc_graph_input,
        alloc_graph_output,
    )

    # Create a table of memory spaces and their base addresses, total memory sizes, and peak memory usage
    memory_names, base_addrs = memory_config.memory_names, memory_config.base_addrs
    memory_usage_table = [
        [
            f"{(i + 1) if memory_names is None else memory_names[i]}",
            None if base_addrs is None else hex(base_addrs[i]),
            memory_config.memory_sizes[i],
            peak_memory_usages_per_memory[i],
        ]
        for i in range(len(peak_memory_usages_per_memory))
    ]

    # Print the memory usage per memory space as a table
    logging.info(
        tabulate(
            memory_usage_table,
            headers=[
                "Memory Space",
                "Base Address",
                "Memory Size (Bytes)",
                "Peak Memory Usage (Bytes)",
            ],
            tablefmt="outline",
        )
    )

    # Get the total peak memory usage across all memory spaces
    total_peak_memory_usage = find_peak_memory_usage(
        executorch_prog.exported_program().graph_module,
        alloc_graph_input,
        alloc_graph_output,
    )

    # Create a table with total peak memory usage and node at which this occurs
    total_memory_usage_table = [
        [
            "Peak memory usage across all spaces",
            f"{total_peak_memory_usage[0]} bytes",
            f"Node {total_peak_memory_usage[1]}",
        ]
    ]

    # Print the total memory usage as a table
    print(
        tabulate(
            total_memory_usage_table,
            tablefmt="outline",
        )
    )

class CadenceMemoryPlanning:
    def __init__(
        self,
        memory_config: MemoryConfig,
        mem_algo: int,
        alloc_graph_input: bool = True,
        alloc_graph_output: bool = True,
    ) -> None:
        self._init_mem_algos()

        self.memory_config = memory_config
        self.mem_algo = mem_algo
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output

    def _init_mem_algos(self) -> None:
        self.available_mem_algos = [
            position_based_greedy_with_hierarchy,
            greedy_by_size_for_offset_calculation_with_hierarchy,
        ]

    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
        algo = partial(
            self.available_mem_algos[self.mem_algo],
            memory_config=self.memory_config,
        )
        # Create the memory planning pass. We allocate memory for input
        # (output) tensors if alloc_graph_input (alloc_graph_output) is
        # True.
        mem_planning = MemoryPlanningPass(
            algo,
            allow_lifetime_and_storage_overlap=False,
            alloc_graph_input=self.alloc_graph_input,
            alloc_graph_output=self.alloc_graph_output,
        )
        mem_planning(graph_module)

        return PassResult(graph_module, True)
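
For context, here is a minimal usage sketch of the new pass, not taken from the diff. It assumes the file lands at a module path like executorch.backends.cadence.aot.memory_planning, that MemoryConfig can be constructed with a memory_sizes keyword argument matching the attribute the pass reads (memory_names and base_addrs left unset), and that the caller already has a torch.fx.GraphModule from an exported Edge/ExecuTorch program; the memory sizes below are made-up values.

# Illustrative sketch only: the import path, MemoryConfig constructor keyword,
# and memory sizes are assumptions, not part of this commit.
import torch

from executorch.backends.cadence.aot.memory_planning import CadenceMemoryPlanning
from executorch.backends.cadence.aot.utils import MemoryConfig


def plan_memory(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Two example memories (mem_id 1 and 2): a small fast scratch and a larger slow one.
    memory_config = MemoryConfig(memory_sizes=[64 * 1024, 4 * 1024 * 1024])

    # mem_algo=0 -> position_based_greedy_with_hierarchy
    # mem_algo=1 -> greedy_by_size_for_offset_calculation_with_hierarchy
    planning = CadenceMemoryPlanning(
        memory_config,
        mem_algo=1,
        alloc_graph_input=True,
        alloc_graph_output=True,
    )
    result = planning(graph_module)
    return result.graph_module

Selecting mem_algo=1 picks the paper-based greedy-by-size placement; both algorithms raise MemoryError if a tensor cannot be placed in any memory in the hierarchy.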
