
Commit adeaa31

[ET][Memory planning] Improve greedy memory planning.
Pull Request resolved: #7926

This diff replaces the old greedy algorithm. The old algorithm produced plans roughly 35% worse than the theoretical optimum, which matters even more for long-context models, where the extra overhead can amount to a few hundred MB. For example, the theoretical optimum for llama3_2 8B, 4-bit quantized, with a context length of 2k is about 1 GB of memory; this optimum can be observed by looking at the peaks in the memory profile. The current algorithm planned about 1.6 GB of memory; the new algorithm reduces that to about 1.1 GB.

ghstack-source-id: 263000139
@exported-using-ghexport

Differential Revision: [D68448332](https://our.internmc.facebook.com/intern/diff/D68448332/)
1 parent c7c4007 commit adeaa31

5 files changed: +248 −32 lines changed


backends/vulkan/vulkan_preprocess.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -6,6 +6,8 @@
 
 # pyre-strict
 
+from functools import partial
+
 from typing import Any, Dict, final, List
 
 import executorch.backends.vulkan.utils as utils
@@ -18,6 +20,9 @@
 from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
 
+from executorch.exir.memory_planning import (
+    greedy,
+)
 from executorch.backends.vulkan._passes import (
     insert_prepack_nodes,
     RemoveLocalScalarDenseOpsTransform,
@@ -189,11 +194,12 @@ def preprocess( # noqa: C901
 
         # Finally, apply dynamic shape passes and memory planning pass. These passes
         # must be applied only when the graph structure is finalized.
+        greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
         program = apply_passes(
             program,
             [
                 ConstraintBasedSymShapeEvalPass(),
-                MemoryPlanningPass(),
+                MemoryPlanningPass(memory_planning_algo=greedy_memory_planning),
             ],
         )
 
```
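For context, the change above boils down to choosing which planner callable the pass will run. A minimal sketch of that configuration, assuming the import paths of the files touched in this PR (the `MemoryPlanningPass` import path shown here is inferred from `exir/passes/memory_planning_pass.py`):

```python
from functools import partial

from executorch.exir.memory_planning import greedy
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass

# Vulkan opts out of overlapping allocations: every allocation inside a shared
# object then sits at offset 0, so an object is only reused by tensors whose
# lifetimes do not overlap.
greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
memory_planning_pass = MemoryPlanningPass(memory_planning_algo=greedy_memory_planning)
```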
exir/memory_planning.py

Lines changed: 176 additions & 26 deletions
```diff
@@ -11,7 +11,7 @@
 import operator
 import typing
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
@@ -117,6 +117,17 @@ def storage_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
 
         return has_overlap
 
+    @classmethod
+    def _debug_message_from_specs(
+        cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec
+    ) -> str:
+        message = (
+            f"lhs life time: {lhs_spec.lifetime}, rhs lifetime: {rhs_spec.lifetime} "
+        )
+        message += f"lhs: mem_id {lhs_spec.mem_id} storage: {lhs_spec.mem_offset}, {lhs_spec.allocated_memory} "
+        message += f"rhs: mem_id {rhs_spec.mem_id} storage: {rhs_spec.mem_offset}, {rhs_spec.allocated_memory}"
+        return message
+
     def verify_storage_reuse(
         self, allow_lifetime_and_storage_overlap: bool = False
     ) -> int:
@@ -159,7 +170,7 @@ def verify_storage_reuse(
                     lhs_spec, rhs_spec
                 ):
                     raise InternalError(
-                        f"Unexpected storage overlap: lhs {lhs_spec}, rhs {rhs_spec}"
+                        f"Unexpected storage overlap: {Verifier._debug_message_from_specs(lhs_spec, rhs_spec)}"
                     )
 
         # Check that each mem_obj_id is consistent with whether the tensors have
@@ -454,6 +465,18 @@ def update_all_tensors_lifetime(
     return specs
 
 
+@dataclass
+class AllocationSpec:
+    """
+    AllocationSpec is used to represent the allocation of a tensor.
+    """
+
+    # The offset of the tensor in the shared object/pool.
+    offset: int
+    # TensorSpec
+    spec: TensorSpec
+
+
 @dataclass
 class SharedObject:
     r"""
@@ -470,8 +493,15 @@ class SharedObject:
     offset: int
     # size of this shared object in bytes
     size: int
+    # When the object is first created
+    first_used_index: int
     # the object will be available for index (last_used_index + 1)
     last_used_index: int
+    # list of allocations belong to this shared object
+    allocations: List[AllocationSpec] = field(default_factory=list)
+
+    def __repr__(self) -> str:
+        return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"
 
 
 def materialize_buffer(
```
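One small detail in the `SharedObject` change above: `allocations` is declared with `field(default_factory=list)` so each shared object owns its own allocation list; a plain mutable default would be shared state. A self-contained sketch of that pattern (stand-in `Bucket` class, not the ExecuTorch one):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Bucket:
    idx: int
    # default_factory gives every Bucket a fresh list; writing
    # `allocations: List[int] = []` would raise at class definition time,
    # precisely to prevent instances from sharing one list.
    allocations: List[int] = field(default_factory=list)


a, b = Bucket(0), Bucket(1)
a.allocations.append(42)
assert b.allocations == []  # b has its own, still-empty list
```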
```diff
@@ -489,35 +519,124 @@ def materialize_buffer(
     return total_size
 
 
-def _size_abs_dif(sobj: SharedObject, spec: TensorSpec) -> int:
+def _does_not_overlap(sobj: SharedObject, spec: TensorSpec) -> bool:
     r"""
-    Calculate the absolute different between the size of a shared object and
-    a tensor.
+    Check if a shared object and a tensor do not overlap.
     """
-    return abs(sobj.size - spec.allocated_memory)
+    for alloc in sobj.allocations:
+        if not (
+            spec.lifetime[1] < alloc.spec.lifetime[0]
+            or spec.lifetime[0] > alloc.spec.lifetime[1]
+        ):
+            return False
+    return True
+
+
+def _find_max_overlapping_allocations_offset(
+    sobj: SharedObject, spec: TensorSpec
+) -> int:
+    max_offset = 0
+    for alloc in sobj.allocations:
+        if (
+            spec.lifetime[1] < alloc.spec.lifetime[0]
+            or spec.lifetime[0] > alloc.spec.lifetime[1]
+        ):
+            continue
+        max_offset = max(alloc.offset + alloc.spec.allocated_memory, max_offset)
+    return max_offset
 
 
 def pick_shared_obj(
-    shared_objects: List[SharedObject], spec: TensorSpec
+    shared_objects: List[SharedObject],
+    spec: TensorSpec,
+    allow_overlapping_allocations: bool = True,
 ) -> SharedObject:
     r"""
-    Pick the available shared object with closest size to the tensor.
-    If there are no available shared object left, create a new one.
+    Pick the available shared object to which to assign this spec,
+    or create a new one
+    Algorithm details
+    Previous: Look at every spec in chronological order. Find if previously allocated object
+    allows it to fit in. If not, allocate a new object.
+    New:
+    - Sort all the specs by allocation size
+    - Process the specs in order
+    - If the spec's size in smaller than previously allocated buckets:
+      - Conditions under which previously allocated bucket can be used:
+        - Lifetime of the spec does not overlap with lifetime of the bucket.
+          - In this case allocate spec to that bucket and expand its lifetime.
+          - Spec is allocated at offset = 0 in this bucket.
+          - Add this spec to allocated object's list of specs.
+        - Lifetime of the spec overlaps with lifetime of the bucket,
+          partially or fully (e.g. spec's lifetime subset of bucket's lifetime)
+          - If none of the specs in the bucket overlaps with spec's lifetime.
+            - Allocate spec to the bucket at offset = 0.
+            - Add this spec to the bucket's list of specs.
+            - Expand bucket's lifetime accounting for added spec's lifetime.
+          - If one or more specs in the bucket overlaps with spec's lifetime.
+            - Collect offsets (at which the given overlapping spec is allocated in the bucket)
+              of all the overlapping specs, and find the max offset.
+            - Allocate spec to the bucket at offset = max_offset + max_offset_spec_size.
+            - Add this spec to the bucket's list of specs.
+            - Expand bucket's lifetime accounting for added spec's lifetime.
+      - If none of these conditions are met, allocate a new bucket.
+        - Add spec to this bucket.
+        - Update bucket's lifetime to that of the spec.
+    - If the spec's size is larger than previously allocated buckets, allocate a new bucket.
+      - Size and lifetime of this bucket is that of the spec
+
+    Proof of correctness:
+    - If allocating a new bucket, it is correct.
+    - If allocating spec to an existing bucket, whose lifetime does not overlap with any
+      of the previously allocated specs' lifetime, then the allocation is correct.
+    Proof of correctness by induction when adding spec to an existing bucket:
+    - If all previous allocations in the given bucket are correct:
+      - Then the new one being added must be correct because when the requested allocation
+        overlaps with one or more previous allocations, we find the largest offset among
+        all the overlapping allocations, and allocate the new spec at that offset. Hence,
+        the allocation at such an offset, will not overlap with any previous allocations.
+    Base case: A newly added allocation within a bucket with single allocation is correct:
+    because a) it must fit and b) its lifetime must not overlap with object's lifetime.
+    This holds true because of the following invariants:
+    - Once a bucket is created, it is never resized.
+    - All the allocations within a bucket follow this:
+      - Span, defined by allocation's offset + size, of two allocations can only overlap,
+        if their timelines do not overlap.
     """
-    # TODO: do better than linear scan
     picked = None
     for sobj in shared_objects:
-        if spec.lifetime[0] > sobj.last_used_index:
-            if picked is None or _size_abs_dif(sobj, spec) < _size_abs_dif(
-                picked, spec
-            ):
-                picked = sobj
-                sobj.last_used_index = spec.lifetime[1]
-                sobj.size = max(sobj.size, spec.allocated_memory)
+        if _does_not_overlap(sobj, spec):
+            assert sobj.size >= spec.allocated_memory, "Allocation specs are not sorted"
+            picked = sobj
+            sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
+            sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
+            allocation_spec = AllocationSpec(0, spec)
+            picked.allocations.append(allocation_spec)
+            break
+
+    if picked is None and allow_overlapping_allocations:
+        for sobj in shared_objects:
+            max_offset = _find_max_overlapping_allocations_offset(sobj, spec)
+            if max_offset > 0:
+                if max_offset + spec.allocated_memory <= sobj.size:
+                    picked = sobj
+                    sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
+                    sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
+                    allocation_spec = AllocationSpec(max_offset, spec)
+                    picked.allocations.append(allocation_spec)
+                    break
+
     if picked is None:
         picked = SharedObject(
-            len(shared_objects), -1, spec.allocated_memory, spec.lifetime[1]
+            len(shared_objects),
+            -1,
+            spec.allocated_memory,
+            spec.lifetime[0],
+            spec.lifetime[1],
         )
+        allocation_spec = AllocationSpec(0, spec)
+        picked.allocations.append(allocation_spec)
+        picked.first_used_index = spec.lifetime[0]
+        picked.last_used_index = spec.lifetime[1]
         shared_objects.append(picked)
 
     return picked
```
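To make the two placement rules in `pick_shared_obj` concrete, here is a standalone sketch (simplified stand-in types and hand-picked numbers, not the ExecuTorch classes): an incoming tensor reuses a bucket at offset 0 when its lifetime conflicts with nothing already in the bucket; otherwise, if overlapping allocations are allowed, it is stacked directly above the highest-ending conflicting allocation:

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Alloc:
    offset: int
    size: int
    lifetime: Tuple[int, int]  # [first_used_index, last_used_index]


@dataclass
class Bucket:
    size: int
    allocs: List[Alloc] = field(default_factory=list)


def lifetimes_overlap(a: Tuple[int, int], b: Tuple[int, int]) -> bool:
    # same interval test as _does_not_overlap, with the negation folded in
    return not (a[1] < b[0] or a[0] > b[1])


def place(bucket: Bucket, size: int, lifetime: Tuple[int, int]) -> int:
    """Return the offset at which a new allocation would land in this bucket."""
    conflicting = [a for a in bucket.allocs if lifetimes_overlap(a.lifetime, lifetime)]
    if not conflicting:
        return 0  # no live conflict: reuse the bucket at offset 0
    # mirrors _find_max_overlapping_allocations_offset: stack above the highest end
    return max(a.offset + a.size for a in conflicting)


bucket = Bucket(size=1024, allocs=[Alloc(offset=0, size=512, lifetime=(0, 4))])
print(place(bucket, size=256, lifetime=(6, 9)))  # 0   -> lifetimes disjoint
print(place(bucket, size=256, lifetime=(3, 7)))  # 512 -> overlaps, stacked above
```

In the real function the stacked placement is additionally gated by `allow_overlapping_allocations` and by the fit check `max_offset + spec.allocated_memory <= sobj.size`, and the offset-0 reuse relies on the size-sorted processing order (hence the `assert sobj.size >= spec.allocated_memory`).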
```diff
@@ -556,7 +675,16 @@ def greedy(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
 ) -> List[int]:
+    r"""Greedy algorithm to allocate memory for tensors in the graph.
+    alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
+    alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
+    allow_overlapping_allocations: If set to true, allows for allocations that overlap
+        in their lifetime but are at different offsets in the storage. By default true.
+        This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
+        allocations disabled
+    """
     spec2obj = {}
     shared_objects = defaultdict(list)
     # Don't do assertion in collect_specs_from_nodes if we have already encountered
@@ -565,24 +693,34 @@ def greedy(
     # For each tensor, pick the available shared object with closest size to
     # the tensor. If there are no available shared object left, create a new
     # one.
+    import bisect
+
+    sorted_specs = []
     for spec in collect_specs_from_nodes(
         graph_module.graph.nodes,
         graph_signature,
         do_assertion=do_assertion,
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     ):
+        bisect.insort(sorted_specs, spec, key=lambda x: x.allocated_memory)
+    sorted_specs.reverse()
+
+    for spec in sorted_specs:
         if spec.mem_id is None:
             spec.mem_id = 1
         spec.realign(alignment)
-        spec2obj[spec] = pick_shared_obj(shared_objects[spec.mem_id], spec)
+        spec2obj[spec] = pick_shared_obj(
+            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
+        )
 
     if len(shared_objects) == 0:
         # Cannot find any tensor in the graph that needs to be allocated.
         # Return [0, 0] to be consistent with default behavior of naive.
        total_sizes = [0, 0]
     else:
         total_sizes = [0] * (max(shared_objects.keys()) + 1)
+        num_specs_processed = 0
         for mem_id in shared_objects:
             input_total_size = 0
             if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
@@ -594,13 +732,25 @@ def greedy(
             total_sizes[mem_id] = materialize_buffer(
                 shared_objects[mem_id], input_total_size
             )
-
-    # Since we now know the number of shared objects we need and the size of
-    # each shared object, we can assign offset in the memory buffer for each
-    # shared object.
-    for spec, sobj in spec2obj.items():
-        spec.mem_obj_id = sobj.idx
-        spec.mem_offset = sobj.offset
+            # padding allocation with 64 bytes.
+            # this requirement really for XNNPACK backend which can access tensors
+            # for reading beyond the end of the tensor. This is done for performance
+            # optimizations in XNNPACK.
+            # While account for backend specific requirement is not the right choice
+            # in backend agnostic memory planning, we do it here for now.
+            total_sizes[mem_id] += 64
+            # Since we now know the number of shared objects we need and the size of
+            # each shared object, we can assign offset in the memory buffer for each
+            # shared object.
+            for sobj in shared_objects[mem_id]:
+                for alloc in sobj.allocations:
+                    spec = alloc.spec
+                    alloc.spec.mem_obj_id = sobj.idx
+                    alloc.spec.mem_offset = sobj.offset + alloc.offset
+                    num_specs_processed += 1
+        assert (
+            len(spec2obj) == num_specs_processed
+        ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"
 
     logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
     return total_sizes
```
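Two mechanics in the `greedy` changes are worth seeing in isolation: specs are kept ordered by `allocated_memory` with `bisect.insort` and then reversed, so buckets are created for the largest tensors first, and a tensor's final address is the bucket's base offset in the arena plus the allocation's offset inside the bucket (with the arena padded by 64 bytes for XNNPACK's over-reads). A small self-contained sketch with plain tuples standing in for `TensorSpec`s (note that the `key=` argument of `bisect.insort` requires Python 3.10+):

```python
import bisect

# (name, allocated_memory) stand-ins for TensorSpecs, in graph order
specs = [("a", 96), ("b", 512), ("c", 64)]

sorted_specs = []
for spec in specs:
    # keeps sorted_specs ordered by ascending size...
    bisect.insort(sorted_specs, spec, key=lambda s: s[1])
# ...and reversing yields the largest-first processing order used by the planner
sorted_specs.reverse()
print(sorted_specs)  # [('b', 512), ('a', 96), ('c', 64)]

# Final placement mirrors the assignment loop at the end of greedy():
bucket_base_offset = 128   # sobj.offset, assigned by materialize_buffer
alloc_offset = 512         # alloc.offset inside the bucket
mem_offset = bucket_base_offset + alloc_offset
print(mem_offset)          # 640
```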

exir/passes/memory_planning_pass.py

Lines changed: 21 additions & 1 deletion
```diff
@@ -6,7 +6,8 @@
 
 import logging
 import warnings
-from typing import Callable, List, Optional
+from typing import Any, Callable, List, Optional
+from functools import partial
 
 import torch
 from executorch.exir.error import internal_assert
@@ -24,6 +25,17 @@
 from torch.export.exported_program import ExportGraphSignature
 
 
+# copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function
+def _callable_name(any_callable: Callable[..., Any]) -> str:
+    if isinstance(any_callable, partial):
+        return any_callable.func.__name__
+
+    try:
+        return any_callable.__name__
+    except AttributeError:
+        return str(any_callable)
+
+
 class MemoryPlanningPass(PassBase):
     def __init__(
         self,
@@ -127,4 +139,12 @@ def run(
             f"The {getattr(self.memory_planning_algo, '__name__', repr(self.memory_planning_algo))} algorithm reuses storage for {num_reuse_pairs} pair of tensors"
         )
         verifier.verify_graph_input_output()
+        if (
+            callable(self.memory_planning_algo)
+            and _callable_name(self.memory_planning_algo) == "greedy"
+        ):
+            # Only verify storage reuse for greedy algorithm
+            # At the moment cadence backends memory planning fails this
+            # I dont know if that is a valid thing but if it is we should adjust verify_storage_reuse function
+            verifier.verify_storage_reuse()
         return PassResult(graph_module, True)
```
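The `greedy` check above works because `_callable_name` unwraps `functools.partial` objects (which have no `__name__` of their own), so the Vulkan configuration `partial(greedy, allow_overlapping_allocations=False)` is still recognized as the greedy planner. A standalone sketch of the same pattern, re-implemented here for illustration:

```python
from functools import partial
from typing import Any, Callable


def callable_name(any_callable: Callable[..., Any]) -> str:
    # partial objects carry the wrapped function in .func
    if isinstance(any_callable, partial):
        return any_callable.func.__name__
    try:
        return any_callable.__name__
    except AttributeError:
        return str(any_callable)


def greedy(*args, **kwargs):  # stand-in for exir.memory_planning.greedy
    ...


print(callable_name(greedy))  # greedy
print(callable_name(partial(greedy, allow_overlapping_allocations=False)))  # greedy
```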

exir/tests/test_joint_graph.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -84,13 +84,13 @@ def forward(self, x, y):
             et.executorch_program.execution_plan[0]
             .values[0]
             .val.allocation_info.memory_offset_low,
-            0,
+            96,
         )
         self.assertEqual(
             et.executorch_program.execution_plan[0]
             .values[1]
             .val.allocation_info.memory_offset_low,
-            48,
+            224,
         )
 
         loss = m(*example_inputs)
```
