Commit 62018d2
[ET][Memory planning] Improve greedy memory planning.
This diff replaces the old greedy algorithm. The old algorithm produced plans about 35% worse than the theoretical optimum. This matters even more for long contexts, where the extra overhead can be a few hundred MB. For example, the theoretical optimum for the llama3_2 8B, 4-bit quantized model with a context length of 2k is about 1G of memory. This theoretical max can be observed by looking at the peaks in the memory profile. The current algorithm planned about 1.6GB of memory; the new algorithm reduces that to about 1.1G. Differential Revision: [D68448332](https://our.internmc.facebook.com/intern/diff/D68448332/) ghstack-source-id: 262854265 Pull Request resolved: #7926
1 parent c7c4007 commit 62018d2
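
The "theoretical max" mentioned in the message is the peak of the memory profile: the largest total size of tensors that are live at the same time. A minimal, hypothetical sketch of that computation (peak_live_memory and its argument layout are invented for illustration, not part of this diff):

from typing import List, Tuple

def peak_live_memory(specs: List[Tuple[int, Tuple[int, int]]]) -> int:
    # specs: (size_in_bytes, (first_used_index, last_used_index)), inclusive.
    # The peak of the profile is a lower bound that no correct planner can beat,
    # since simultaneously live tensors must occupy disjoint bytes.
    last = max(hi for _, (_, hi) in specs)
    return max(
        sum(size for size, (lo, hi) in specs if lo <= t <= hi)
        for t in range(last + 1)
    )

# Example: two 4-byte tensors overlap in time; an 8-byte tensor comes later.
assert peak_live_memory([(4, (0, 2)), (4, (1, 3)), (8, (4, 5))]) == 8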

File tree: 3 files changed, +185 −25

exir/memory_planning.py

Lines changed: 142 additions & 23 deletions
@@ -11,7 +11,7 @@
 import operator
 import typing
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

 import torch
@@ -454,6 +454,18 @@ def update_all_tensors_lifetime(
     return specs


+@dataclass
+class AllocationSpec:
+    """
+    AllocationSpec is used to represent the allocation of a tensor.
+    """
+
+    # The offset of the tensor in the shared object/pool.
+    offset: int
+    # TensorSpec
+    spec: TensorSpec
+
+
 @dataclass
 class SharedObject:
     r"""
@@ -470,8 +482,15 @@ class SharedObject:
     offset: int
     # size of this shared object in bytes
     size: int
+    # index at which this object is first used
+    first_used_index: int
     # the object will be available for index (last_used_index + 1)
     last_used_index: int
+    # list of allocations belonging to this shared object
+    allocations: List[AllocationSpec] = field(default_factory=list)
+
+    def __repr__(self) -> str:
+        return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"


 def materialize_buffer(
@@ -489,35 +508,122 @@ def materialize_buffer(
     return total_size


-def _size_abs_dif(sobj: SharedObject, spec: TensorSpec) -> int:
+def _does_not_overlap(sobj: SharedObject, spec: TensorSpec) -> bool:
     r"""
-    Calculate the absolute different between the size of a shared object and
-    a tensor.
+    Check if a shared object and a tensor do not overlap.
     """
-    return abs(sobj.size - spec.allocated_memory)
+    for alloc in sobj.allocations:
+        if not (
+            spec.lifetime[1] < alloc.spec.lifetime[0]
+            or spec.lifetime[0] > alloc.spec.lifetime[1]
+        ):
+            return False
+    return True
+
+
+def _find_max_overlapping_allocations_offset(
+    sobj: SharedObject, spec: TensorSpec
+) -> int:
+    max_offset = 0
+    for alloc in sobj.allocations:
+        if (
+            spec.lifetime[1] < alloc.spec.lifetime[0]
+            or spec.lifetime[0] > alloc.spec.lifetime[1]
+        ):
+            continue
+        max_offset = max(alloc.offset + alloc.spec.allocated_memory, max_offset)
+    return max_offset


 def pick_shared_obj(
     shared_objects: List[SharedObject], spec: TensorSpec
 ) -> SharedObject:
     r"""
-    Pick the available shared object with closest size to the tensor.
-    If there are no available shared object left, create a new one.
+    Pick the available shared object to which to assign this spec,
+    or create a new one.
+
+    Algorithm details:
+    Previous: Look at every spec in chronological order. Find if a previously
+    allocated object allows it to fit. If not, allocate a new object.
+    New:
+    - Sort all the specs by allocation size.
+    - Process the specs in order.
+    - If the spec's size is smaller than previously allocated buckets:
+      - Conditions under which a previously allocated bucket can be used:
+        - Lifetime of the spec does not overlap with lifetime of the bucket.
+          - In this case, allocate the spec to that bucket and expand its lifetime.
+          - The spec is allocated at offset = 0 in this bucket.
+          - Add this spec to the allocated object's list of specs.
+        - Lifetime of the spec overlaps with lifetime of the bucket,
+          partially or fully (e.g. the spec's lifetime is a subset of the bucket's):
+          - If none of the specs in the bucket overlap with the spec's lifetime:
+            - Allocate the spec to the bucket at offset = 0.
+            - Add this spec to the bucket's list of specs.
+            - Expand the bucket's lifetime to account for the added spec's lifetime.
+          - If one or more specs in the bucket overlap with the spec's lifetime:
+            - Collect the offsets (at which each overlapping spec is allocated
+              in the bucket) of all the overlapping specs, and find the max offset.
+            - Allocate the spec to the bucket at offset = max_offset + max_offset_spec_size.
+            - Add this spec to the bucket's list of specs.
+            - Expand the bucket's lifetime to account for the added spec's lifetime.
+      - If none of these conditions are met, allocate a new bucket.
+        - Add the spec to this bucket.
+        - Update the bucket's lifetime to that of the spec.
+    - If the spec's size is larger than all previously allocated buckets, allocate
+      a new bucket.
+      - Size and lifetime of this bucket are those of the spec.
+
+    Proof of correctness:
+    - If allocating a new bucket, it is correct.
+    - If allocating a spec to an existing bucket whose lifetime does not overlap
+      with any of the previously allocated specs' lifetimes, the allocation is correct.
+    Proof of correctness by induction when adding a spec to an existing bucket:
+    - If all previous allocations in the given bucket are correct:
+      - Then the new one being added must be correct, because when the requested
+        allocation overlaps with one or more previous allocations, we find the
+        largest offset among all the overlapping allocations and allocate the new
+        spec at that offset. Hence the allocation at such an offset will not
+        overlap with any previous allocations.
+    Base case: a newly added allocation within a bucket containing a single
+    allocation is correct, because a) it must fit and b) its lifetime must not
+    overlap with the object's lifetime.
+    This holds true because of the following invariants:
+    - Once a bucket is created, it is never resized.
+    - All the allocations within a bucket follow this rule:
+      - The spans, defined by allocation offset + size, of two allocations can
+        only overlap if their lifetimes do not overlap.
     """
-    # TODO: do better than linear scan
     picked = None
     for sobj in shared_objects:
-        if spec.lifetime[0] > sobj.last_used_index:
-            if picked is None or _size_abs_dif(sobj, spec) < _size_abs_dif(
-                picked, spec
-            ):
-                picked = sobj
-                sobj.last_used_index = spec.lifetime[1]
-                sobj.size = max(sobj.size, spec.allocated_memory)
+        if _does_not_overlap(sobj, spec):
+            assert sobj.size >= spec.allocated_memory, "Allocation specs are not sorted"
+            picked = sobj
+            sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
+            sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
+            allocation_spec = AllocationSpec(0, spec)
+            picked.allocations.append(allocation_spec)
+            break
+
+    if picked is None:
+        for sobj in shared_objects:
+            max_offset = _find_max_overlapping_allocations_offset(sobj, spec)
+            if max_offset > 0:
+                if max_offset + spec.allocated_memory <= sobj.size:
+                    picked = sobj
+                    sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
+                    sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
+                    allocation_spec = AllocationSpec(max_offset, spec)
+                    picked.allocations.append(allocation_spec)
+                    break
+
     if picked is None:
         picked = SharedObject(
-            len(shared_objects), -1, spec.allocated_memory, spec.lifetime[1]
+            len(shared_objects),
+            -1,
+            spec.allocated_memory,
+            spec.lifetime[0],
+            spec.lifetime[1],
         )
+        allocation_spec = AllocationSpec(0, spec)
+        picked.allocations.append(allocation_spec)
+        picked.first_used_index = spec.lifetime[0]
+        picked.last_used_index = spec.lifetime[1]
         shared_objects.append(picked)

     return picked
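
Taken together, the hunk above turns pick_shared_obj into a largest-first greedy that stacks tensors by offset inside fixed-size buckets. A self-contained sketch of the same strategy under simplified assumptions (Bucket, Alloc, and plan are invented names, not the ExecuTorch API; specs are (size, lifetime) pairs):

from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class Alloc:
    offset: int
    size: int
    lifetime: Tuple[int, int]  # inclusive [first, last] node indices

@dataclass
class Bucket:
    size: int  # fixed at creation, mirroring the "never resized" invariant
    allocations: List[Alloc] = field(default_factory=list)

def overlaps(a: Tuple[int, int], b: Tuple[int, int]) -> bool:
    # Two inclusive intervals overlap unless one ends before the other starts.
    return not (a[1] < b[0] or a[0] > b[1])

def plan(specs: List[Tuple[int, Tuple[int, int]]]) -> List[Bucket]:
    buckets: List[Bucket] = []
    # Largest first, so any bucket reused at offset 0 is guaranteed big enough.
    for size, lifetime in sorted(specs, key=lambda s: s[0], reverse=True):
        placed = False
        for b in buckets:
            live = [a for a in b.allocations if overlaps(a.lifetime, lifetime)]
            if not live:
                b.allocations.append(Alloc(0, size, lifetime))  # no conflict
                placed = True
                break
            # Stack above the highest-ending conflicting allocation, if it fits.
            off = max(a.offset + a.size for a in live)
            if off + size <= b.size:
                b.allocations.append(Alloc(off, size, lifetime))
                placed = True
                break
        if not placed:
            buckets.append(Bucket(size, [Alloc(0, size, lifetime)]))
    return buckets

# Example: a 64-byte tensor dies before two smaller ones become live, so all
# three share one 64-byte bucket (the 16 stacked at offset 32 above the 32),
# instead of the 112 bytes that separate buckets would need.
assert [b.size for b in plan([(64, (0, 1)), (32, (2, 3)), (16, (2, 3))])] == [64]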
@@ -565,13 +671,20 @@ def greedy(
     # For each tensor, pick the available shared object with closest size to
     # the tensor. If there are no available shared object left, create a new
     # one.
+    import bisect
+
+    sorted_specs = []
     for spec in collect_specs_from_nodes(
         graph_module.graph.nodes,
         graph_signature,
         do_assertion=do_assertion,
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     ):
+        bisect.insort(sorted_specs, spec, key=lambda x: x.allocated_memory)
+    sorted_specs.reverse()
+
+    for spec in sorted_specs:
         if spec.mem_id is None:
             spec.mem_id = 1
         spec.realign(alignment)
@@ -583,6 +696,7 @@ def greedy(
         total_sizes = [0, 0]
     else:
         total_sizes = [0] * (max(shared_objects.keys()) + 1)
+        num_specs_processed = 0
         for mem_id in shared_objects:
             input_total_size = 0
             if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
@@ -594,13 +708,18 @@ def greedy(
             total_sizes[mem_id] = materialize_buffer(
                 shared_objects[mem_id], input_total_size
             )
-
-    # Since we now know the number of shared objects we need and the size of
-    # each shared object, we can assign offset in the memory buffer for each
-    # shared object.
-    for spec, sobj in spec2obj.items():
-        spec.mem_obj_id = sobj.idx
-        spec.mem_offset = sobj.offset
+            # Since we now know the number of shared objects we need and the size of
+            # each shared object, we can assign offset in the memory buffer for each
+            # shared object.
+            for sobj in shared_objects[mem_id]:
+                for alloc in sobj.allocations:
+                    spec = alloc.spec
+                    alloc.spec.mem_obj_id = sobj.idx
+                    alloc.spec.mem_offset = sobj.offset + alloc.offset
+                    num_specs_processed += 1
+        assert (
+            len(spec2obj) == num_specs_processed
+        ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"

     logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
     return total_sizes
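
One portability note on the greedy() hunk above: bisect.insort only accepts a key argument on Python 3.10 or newer, so this change assumes a 3.10+ runtime. On older interpreters, a sketch of an equivalent largest-first ordering (specs is a stand-in for the collected TensorSpecs):

# Equivalent ordering without the 3.10-only key= parameter of bisect.insort:
sorted_specs = sorted(specs, key=lambda x: x.allocated_memory, reverse=True)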

exir/passes/memory_planning_pass.py

Lines changed: 1 addition & 0 deletions
@@ -127,4 +127,5 @@ def run(
             f"The {getattr(self.memory_planning_algo, '__name__', repr(self.memory_planning_algo))} algorithm reuses storage for {num_reuse_pairs} pair of tensors"
         )
         verifier.verify_graph_input_output()
+        verifier.verify_storage_reuse()
         return PassResult(graph_module, True)
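
The newly enabled verify_storage_reuse() gates the new planner. Roughly, the invariant such a check must enforce is that two planned tensors may share bytes only if their lifetimes are disjoint. A simplified, hypothetical illustration (the Spec tuple layout is invented; the real verifier works on TensorSpecs):

from typing import List, Tuple

# (mem_offset, size_in_bytes, (first_used_index, last_used_index))
Spec = Tuple[int, int, Tuple[int, int]]

def check_no_illegal_aliasing(specs: List[Spec]) -> None:
    # Pairwise check: overlapping byte ranges are legal only when the two
    # tensors are never live at the same time.
    for i, (off_a, size_a, life_a) in enumerate(specs):
        for off_b, size_b, life_b in specs[i + 1 :]:
            bytes_overlap = off_a < off_b + size_b and off_b < off_a + size_a
            lives_overlap = not (life_a[1] < life_b[0] or life_b[1] < life_a[0])
            assert not (bytes_overlap and lives_overlap), "illegal aliasing"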

exir/tests/test_memory_planning.py

Lines changed: 42 additions & 2 deletions
@@ -106,6 +106,28 @@ def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
         return (torch.randn(2),)


+class LinearsWithDifferentSizeAndViewOps(torch.nn.Module):
+    def __init__(self) -> None:
+        super(LinearsWithDifferentSizeAndViewOps, self).__init__()
+        self.linears = torch.nn.ModuleList()
+        for x in [8, 16, 32, 64]:
+            self.linears.append(torch.nn.Linear(x, x * 2))
+
+    def forward(self, i: torch.Tensor) -> torch.Tensor:
+        o1 = i
+        for linear in self.linears:
+            o1 = linear(o1)
+        o1 = o1.view(-1, 64, 2)
+        o1 = o1 + 1
+        o2 = i
+        for linear in self.linears:
+            o2 = linear(o2)
+        return o1.view(-1, 128) + o2
+
+    def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
+        return (torch.randn(3, 8),)
+
+
 class ModuleReturnTwo(nn.Module):
     def __init__(self) -> None:
         super(ModuleReturnTwo, self).__init__()
@@ -360,6 +382,13 @@ def verify_overlap_placeholders(
         ],
     )

+    test_linear_with_view: Callable[..., None] = maketest(
+        LinearsWithDifferentSizeAndViewOps,
+        criteria=[
+            (greedy, True),
+        ],
+    )
+
     # greedy algorithm will reuse memory if we let the algorithm allocate
     # memory for both graph input and output.
     test_list_arg: Callable[..., None] = maketest(
@@ -508,15 +537,26 @@ def test_multiple_pools(
         verifier.verify_graph_input_output()

         idx = 0
+        reference_output = dict()
+        actual_output = dict()
         for node in graph_module.graph.nodes:
             if node.op == "placeholder" or (
                 node.op == "call_function"
                 and node.target in (torch.ops.aten.add.out, torch.ops.aten.mul.out)
             ):
                 mem_id, mem_offset = expected_allocs[idx]
-                self.assertEqual(node.meta["spec"].mem_id, mem_id)
-                self.assertEqual(node.meta["spec"].mem_offset, mem_offset)
+                actual_mem_id, actual_mem_offset = (
+                    node.meta["spec"].mem_id,
+                    node.meta["spec"].mem_offset,
+                )
+                if (mem_id, mem_offset) not in reference_output:
+                    reference_output[(mem_id, mem_offset)] = 1
+                    actual_output[(actual_mem_id, actual_mem_offset)] = 1
+                else:
+                    reference_output[(mem_id, mem_offset)] += 1
+                    actual_output[(actual_mem_id, actual_mem_offset)] += 1
                 idx += 1
+        self.assertEqual(reference_output, actual_output)
         self.assertEqual(graph_module.meta["non_const_buffer_sizes"], expected_bufsizes)

     def test_constants_not_memory_planned(self) -> None:
