
Commit 510d795

[ET][Memory planning] Improve greedy memory planning.
Pull Request resolved: #7926

This diff replaces the old greedy algorithm. The older algorithm produced plans about 35% worse than the theoretical optimum. This matters even more for long context, since the additional overhead can be a few hundred MB. For example, the theoretical optimum for llama3_2 8B, 4-bit quantized, with a context length of 2k is about 1G of memory; this theoretical max can be observed by looking at the peaks in the memory profile. The current algorithm planned about 1.6GB of memory; the new algorithm reduces that to about 1.1G.

ghstack-source-id: 263342052
@exported-using-ghexport

Differential Revision: [D68448332](https://our.internmc.facebook.com/intern/diff/D68448332/)
1 parent 801bde1 commit 510d795

File tree

5 files changed: +270 -33 lines changed

backends/vulkan/vulkan_preprocess.py

Lines changed: 6 additions & 2 deletions
@@ -6,6 +6,8 @@
 
 # pyre-strict
 
+from functools import partial
+
 from typing import Any, Dict, final, List
 
 import executorch.backends.vulkan.utils as utils
@@ -17,7 +19,6 @@
 from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
 from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
-
 from executorch.backends.vulkan._passes import (
     insert_prepack_nodes,
     RemoveLocalScalarDenseOpsTransform,
@@ -41,6 +42,8 @@
     PreprocessResult,
 )
 from executorch.exir.backend.utils import DelegateMappingBuilder
+
+from executorch.exir.memory_planning import greedy
 from executorch.exir.pass_base import ExportPass, PassBase
 
 from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
@@ -189,11 +192,12 @@ def preprocess(  # noqa: C901
 
         # Finally, apply dynamic shape passes and memory planning pass. These passes
         # must be applied only when the graph structure is finalized.
+        greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
         program = apply_passes(
            program,
             [
                 ConstraintBasedSymShapeEvalPass(),
-                MemoryPlanningPass(),
+                MemoryPlanningPass(memory_planning_algo=greedy_memory_planning),
             ],
         )
 
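As an aside, the change above relies on functools.partial to pre-bind a keyword argument, so MemoryPlanningPass can call the result exactly like the plain greedy function. A minimal standalone sketch of that pattern; plan and its arguments below are hypothetical stand-ins, not ExecuTorch APIs:

from functools import partial

# Hypothetical planner with the same shape as `greedy`:
# positional args plus an algorithm-specific keyword.
def plan(graph, alignment, *, allow_overlapping_allocations=True):
    mode = "overlapping" if allow_overlapping_allocations else "non-overlapping"
    return f"planned {graph} with alignment {alignment} ({mode})"

# Bind the keyword once; callers keep using the original signature.
vulkan_plan = partial(plan, allow_overlapping_allocations=False)

print(vulkan_plan("graph", 16))  # planned graph with alignment 16 (non-overlapping)

Binding the option at construction time keeps the planner's call signature stable for every caller of the pass.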

exir/memory_planning.py

Lines changed: 199 additions & 26 deletions
@@ -11,7 +11,7 @@
 import operator
 import typing
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
@@ -117,6 +117,17 @@ def storage_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
 
         return has_overlap
 
+    @classmethod
+    def _debug_message_from_specs(
+        cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec
+    ) -> str:
+        message = (
+            f"lhs lifetime: {lhs_spec.lifetime}, rhs lifetime: {rhs_spec.lifetime} "
+        )
+        message += f"lhs: mem_id {lhs_spec.mem_id} storage: {lhs_spec.mem_offset}, {lhs_spec.allocated_memory} "
+        message += f"rhs: mem_id {rhs_spec.mem_id} storage: {rhs_spec.mem_offset}, {rhs_spec.allocated_memory}"
+        return message
+
     def verify_storage_reuse(
         self, allow_lifetime_and_storage_overlap: bool = False
     ) -> int:
@@ -159,7 +170,7 @@ def verify_storage_reuse(
                 lhs_spec, rhs_spec
             ):
                 raise InternalError(
-                    f"Unexpected storage overlap: lhs {lhs_spec}, rhs {rhs_spec}"
+                    f"Unexpected storage overlap: {Verifier._debug_message_from_specs(lhs_spec, rhs_spec)}"
                 )
 
     # Check that each mem_obj_id is consistent with whether the tensors have
@@ -454,6 +465,18 @@ def update_all_tensors_lifetime(
     return specs
 
 
+@dataclass
+class AllocationSpec:
+    """
+    AllocationSpec represents the allocation of a tensor.
+    """
+
+    # the offset of the tensor in the shared object/pool
+    offset: int
+    # the TensorSpec being allocated
+    spec: TensorSpec
+
+
 @dataclass
 class SharedObject:
     r"""
@@ -470,8 +493,15 @@ class SharedObject:
     offset: int
     # size of this shared object in bytes
     size: int
+    # index at which the object is first used
+    first_used_index: int
     # the object will be available for index (last_used_index + 1)
     last_used_index: int
+    # list of allocations belonging to this shared object
+    allocations: List[AllocationSpec] = field(default_factory=list)
+
+    def __repr__(self) -> str:
+        return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"
 
 
 def materialize_buffer(
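To make the new data model concrete before the next hunk, a standalone sketch of how AllocationSpec entries accumulate inside a SharedObject. FakeSpec, Alloc, and Bucket are simplified stand-ins for TensorSpec, AllocationSpec, and SharedObject, keeping only the fields the planner reads:

from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class FakeSpec:  # stand-in for TensorSpec: just lifetime and size
    lifetime: Tuple[int, int]
    allocated_memory: int

@dataclass
class Alloc:  # mirrors AllocationSpec: offset within the shared object
    offset: int
    spec: FakeSpec

@dataclass
class Bucket:  # mirrors SharedObject
    idx: int
    offset: int
    size: int
    first_used_index: int
    last_used_index: int
    allocations: List[Alloc] = field(default_factory=list)

b = Bucket(idx=0, offset=-1, size=1024, first_used_index=0, last_used_index=3)
b.allocations.append(Alloc(0, FakeSpec((0, 3), 1024)))
# A tensor whose lifetime is disjoint from the first one can reuse offset 0;
# an overlapping tensor would have to stack above it (and fit in the bucket).
b.allocations.append(Alloc(0, FakeSpec((5, 7), 512)))
b.last_used_index = 7
print(len(b.allocations))  # 2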
@@ -489,35 +519,124 @@ def materialize_buffer(
     return total_size
 
 
-def _size_abs_dif(sobj: SharedObject, spec: TensorSpec) -> int:
+def _does_not_overlap(sobj: SharedObject, spec: TensorSpec) -> bool:
     r"""
-    Calculate the absolute difference between the size of a shared object and
-    a tensor.
+    Check if a shared object and a tensor do not overlap in lifetime.
     """
-    return abs(sobj.size - spec.allocated_memory)
+    for alloc in sobj.allocations:
+        if not (
+            spec.lifetime[1] < alloc.spec.lifetime[0]
+            or spec.lifetime[0] > alloc.spec.lifetime[1]
+        ):
+            return False
+    return True
+
+
+def _find_max_overlapping_allocations_offset(
+    sobj: SharedObject, spec: TensorSpec
+) -> int:
+    max_offset = 0
+    for alloc in sobj.allocations:
+        if (
+            spec.lifetime[1] < alloc.spec.lifetime[0]
+            or spec.lifetime[0] > alloc.spec.lifetime[1]
+        ):
+            continue
+        max_offset = max(alloc.offset + alloc.spec.allocated_memory, max_offset)
+    return max_offset
 
 
 def pick_shared_obj(
-    shared_objects: List[SharedObject], spec: TensorSpec
+    shared_objects: List[SharedObject],
+    spec: TensorSpec,
+    allow_overlapping_allocations: bool = True,
 ) -> SharedObject:
     r"""
-    Pick the available shared object with closest size to the tensor.
-    If there are no available shared objects left, create a new one.
+    Pick the available shared object to which to assign this spec,
+    or create a new one.
+
+    Algorithm details
+    Previous: look at every spec in chronological order; if a previously
+    allocated object lets it fit, reuse it, otherwise allocate a new object.
+    New:
+    - Sort all the specs by allocation size, largest first.
+    - Process the specs in order.
+    - If the spec's size is smaller than a previously allocated bucket, the
+      bucket can be reused under these conditions:
+      - The spec's lifetime does not overlap with the bucket's lifetime:
+        - Allocate the spec to the bucket at offset = 0 and expand the
+          bucket's lifetime to cover the spec's lifetime.
+        - Add the spec to the bucket's list of allocations.
+      - The spec's lifetime overlaps with the bucket's lifetime, partially or
+        fully (e.g. the spec's lifetime is a subset of the bucket's):
+        - If none of the allocations in the bucket overlap with the spec's lifetime:
+          - Allocate the spec to the bucket at offset = 0.
+          - Add the spec to the bucket's list of allocations.
+          - Expand the bucket's lifetime to account for the spec's lifetime.
+        - If one or more allocations in the bucket overlap with the spec's lifetime:
+          - Over all overlapping allocations, find the maximum of
+            (allocation offset + allocation size).
+          - Allocate the spec to the bucket at that max offset, provided it fits.
+          - Add the spec to the bucket's list of allocations.
+          - Expand the bucket's lifetime to account for the spec's lifetime.
+    - If none of these conditions are met, allocate a new bucket:
+      - Add the spec to the new bucket.
+      - Set the bucket's lifetime to that of the spec.
+    - If the spec's size is larger than every previously allocated bucket,
+      allocate a new bucket whose size and lifetime are those of the spec.
+
+    Proof of correctness:
+    - Allocating a new bucket is trivially correct.
+    - Allocating a spec to an existing bucket whose allocations' lifetimes do
+      not overlap with the spec's lifetime is correct.
+    - Proof by induction when adding a spec to an existing bucket:
+      - If all previous allocations in the bucket are correct, the new one must
+        be correct too: when the requested allocation overlaps with one or more
+        previous allocations, we find the largest end offset among all the
+        overlapping allocations and allocate the new spec at that offset, so it
+        cannot overlap any previous allocation.
+      - Base case: an allocation added to a bucket with a single allocation is
+        correct because (a) it must fit and (b) its lifetime must not overlap
+        with the existing allocation's lifetime.
+    This holds because of the following invariants:
+    - Once a bucket is created, it is never resized.
+    - Within a bucket, the spans (offset + size) of two allocations may overlap
+      only if their lifetimes do not overlap.
     """
-    # TODO: do better than linear scan
     picked = None
     for sobj in shared_objects:
-        if spec.lifetime[0] > sobj.last_used_index:
-            if picked is None or _size_abs_dif(sobj, spec) < _size_abs_dif(
-                picked, spec
-            ):
-                picked = sobj
-                sobj.last_used_index = spec.lifetime[1]
-                sobj.size = max(sobj.size, spec.allocated_memory)
+        if _does_not_overlap(sobj, spec):
+            assert sobj.size >= spec.allocated_memory, "Allocation specs are not sorted"
+            picked = sobj
+            sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
+            sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
+            allocation_spec = AllocationSpec(0, spec)
+            picked.allocations.append(allocation_spec)
+            break
+
+    if picked is None and allow_overlapping_allocations:
+        for sobj in shared_objects:
+            max_offset = _find_max_overlapping_allocations_offset(sobj, spec)
+            if max_offset > 0:
+                if max_offset + spec.allocated_memory <= sobj.size:
+                    picked = sobj
+                    sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
+                    sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
+                    allocation_spec = AllocationSpec(max_offset, spec)
+                    picked.allocations.append(allocation_spec)
+                    break
 
     if picked is None:
         picked = SharedObject(
-            len(shared_objects), -1, spec.allocated_memory, spec.lifetime[1]
+            len(shared_objects),
+            -1,
+            spec.allocated_memory,
+            spec.lifetime[0],
+            spec.lifetime[1],
         )
+        allocation_spec = AllocationSpec(0, spec)
+        picked.allocations.append(allocation_spec)
+        picked.first_used_index = spec.lifetime[0]
+        picked.last_used_index = spec.lifetime[1]
         shared_objects.append(picked)
 
     return picked
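A compact illustration of the placement rule described in the docstring above; this is a self-contained sketch, not the ExecuTorch implementation. Lifetimes are inclusive integer intervals, a tensor whose lifetime is disjoint from everything in the bucket lands at offset 0, and an overlapping tensor stacks just above the highest conflicting allocation (bucket-size checks are omitted for brevity):

from typing import List, Tuple

Alloc = Tuple[int, int, int, int]  # (offset, size, start, end)

def overlaps(a_start: int, a_end: int, b_start: int, b_end: int) -> bool:
    # Two inclusive lifetimes overlap unless one ends before the other begins.
    return not (a_end < b_start or a_start > b_end)

def place(allocs: List[Alloc], size: int, start: int, end: int) -> int:
    # Stack the new tensor just above the highest overlapping allocation;
    # offset 0 if nothing is alive at the same time.
    top = 0
    for off, sz, s, e in allocs:
        if overlaps(start, end, s, e):
            top = max(top, off + sz)
    allocs.append((top, size, start, end))
    return top

bucket: List[Alloc] = []
print(place(bucket, 1024, 0, 3))  # 0: empty bucket
print(place(bucket, 512, 2, 5))   # 1024: lifetime overlaps the first tensor
print(place(bucket, 256, 6, 8))   # 0: both earlier tensors are dead by index 6

With inclusive lifetimes, two allocations may share bytes only when one dies strictly before the other is born, which is exactly the invariant the proof of correctness relies on.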
@@ -550,13 +669,50 @@ def get_node_tensor_specs(
     ]
 
 
+# A bit hacky: check whether the graph contains an XNNPACK delegate.
+# Why? See the padding note in greedy() below.
+
+
+def _contains_xnnpack_delegate(graph_module: torch.fx.GraphModule) -> bool:
+    for node in graph_module.graph.nodes:
+        if node.target == executorch_call_delegate:
+            lowered_module = getattr(
+                graph_module.graph.owning_module, node.args[0].target
+            )
+            if "xnnpack" in lowered_module.backend_id.lower():
+                return True
+    return False
+
+
 def greedy(
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
 ) -> List[int]:
+    r"""Greedy algorithm to allocate memory for tensors in the graph.
+
+    alloc_graph_input: if set to True, allocate memory for the graph inputs.
+    alloc_graph_output: if set to True, allocate memory for the graph outputs.
+    allow_overlapping_allocations: if set to True (the default), allow
+        allocations that overlap in their lifetime but sit at different offsets
+        in the storage. This flag exists so that Vulkan can use
+        MemoryPlanningPass with overlapping allocations disabled.
+    """
+    # Pad allocations with 64 bytes.
+    # This requirement is really for the XNNPACK backend, which can read
+    # beyond the end of a tensor as a performance optimization.
+    # While accounting for a backend-specific requirement is not the right
+    # choice in backend-agnostic memory planning, we do it here as it seems
+    # most appropriate. Right now this applies to greedy only, so any other
+    # algorithm that plans memory for the XNNPACK backend will not have it.
+    extra_padded_bytes = 0
+    if _contains_xnnpack_delegate(graph_module):
+        extra_padded_bytes = 64
     spec2obj = {}
     shared_objects = defaultdict(list)
     # Don't do assertion in collect_specs_from_nodes if we have already encountered
@@ -565,24 +721,34 @@ def greedy(
     # For each tensor, pick the available shared object with closest size to
     # the tensor. If there are no available shared objects left, create a new
     # one.
+    import bisect
+
+    sorted_specs = []
     for spec in collect_specs_from_nodes(
         graph_module.graph.nodes,
         graph_signature,
         do_assertion=do_assertion,
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     ):
+        bisect.insort(sorted_specs, spec, key=lambda x: x.allocated_memory)
+    sorted_specs.reverse()
+
+    for spec in sorted_specs:
         if spec.mem_id is None:
             spec.mem_id = 1
         spec.realign(alignment)
-        spec2obj[spec] = pick_shared_obj(shared_objects[spec.mem_id], spec)
+        spec2obj[spec] = pick_shared_obj(
+            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
+        )
 
     if len(shared_objects) == 0:
         # Cannot find any tensor in the graph that needs to be allocated.
         # Return [0, 0] to be consistent with default behavior of naive.
         total_sizes = [0, 0]
     else:
         total_sizes = [0] * (max(shared_objects.keys()) + 1)
+        num_specs_processed = 0
         for mem_id in shared_objects:
             input_total_size = 0
             if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
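One detail worth noting in the hunk above: the key= argument of bisect.insort is only available on Python 3.10 and later. A small sketch of the sort-then-reverse step on stand-in specs (S is a hypothetical stand-in for TensorSpec carrying only the sort key):

import bisect
from dataclasses import dataclass

@dataclass
class S:  # stand-in for TensorSpec; only the sort key matters here
    allocated_memory: int

sorted_specs: list = []
for spec in [S(64), S(1024), S(256)]:
    # key= on insort requires Python 3.10+
    bisect.insort(sorted_specs, spec, key=lambda x: x.allocated_memory)
sorted_specs.reverse()  # largest first, so the big tensors open the buckets
print([s.allocated_memory for s in sorted_specs])  # [1024, 256, 64]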
@@ -594,13 +760,20 @@ def greedy(
             total_sizes[mem_id] = materialize_buffer(
                 shared_objects[mem_id], input_total_size
             )
-
-    # Since we now know the number of shared objects we need and the size of
-    # each shared object, we can assign offset in the memory buffer for each
-    # shared object.
-    for spec, sobj in spec2obj.items():
-        spec.mem_obj_id = sobj.idx
-        spec.mem_offset = sobj.offset
+            total_sizes[mem_id] += extra_padded_bytes
+
+            # Since we now know the number of shared objects we need and the size of
+            # each shared object, we can assign offset in the memory buffer for each
+            # shared object.
+            for sobj in shared_objects[mem_id]:
+                for alloc in sobj.allocations:
+                    spec = alloc.spec
+                    alloc.spec.mem_obj_id = sobj.idx
+                    alloc.spec.mem_offset = sobj.offset + alloc.offset
+                    num_specs_processed += 1
+        assert (
+            len(spec2obj) == num_specs_processed
+        ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"
 
     logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
     return total_sizes
