Skip to content

Commit 8d4749e

Browse files
authored
[ET][Memory planning] Improve greedy memory planning.
Differential Revision: D68448332 Pull Request resolved: #7926
1 parent 2e63ab7 commit 8d4749e

File tree

5 files changed

+270
-33
lines changed

5 files changed

+270
-33
lines changed

backends/vulkan/vulkan_preprocess.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
# pyre-strict
88

9+
from functools import partial
10+
911
from typing import Any, Dict, final, List
1012

1113
import executorch.backends.vulkan.utils as utils
@@ -17,7 +19,6 @@
1719
from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
1820
from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
1921
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
20-
2122
from executorch.backends.vulkan._passes import (
2223
insert_prepack_nodes,
2324
RemoveLocalScalarDenseOpsTransform,
@@ -41,6 +42,8 @@
4142
PreprocessResult,
4243
)
4344
from executorch.exir.backend.utils import DelegateMappingBuilder
45+
46+
from executorch.exir.memory_planning import greedy
4447
from executorch.exir.pass_base import ExportPass, PassBase
4548

4649
from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
@@ -189,11 +192,12 @@ def preprocess( # noqa: C901
189192

190193
# Finally, apply dynamic shape passes and memory planning pass. These passes
191194
# must be applied only when the graph structure is finalized.
195+
greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
192196
program = apply_passes(
193197
program,
194198
[
195199
ConstraintBasedSymShapeEvalPass(),
196-
MemoryPlanningPass(),
200+
MemoryPlanningPass(memory_planning_algo=greedy_memory_planning),
197201
],
198202
)
199203

exir/memory_planning.py

Lines changed: 199 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import operator
1212
import typing
1313
from collections import defaultdict
14-
from dataclasses import dataclass
14+
from dataclasses import dataclass, field
1515
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
1616

1717
import torch
@@ -117,6 +117,17 @@ def storage_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
117117

118118
return has_overlap
119119

120+
@classmethod
121+
def _debug_message_from_specs(
122+
cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec
123+
) -> str:
124+
message = (
125+
f"lhs life time: {lhs_spec.lifetime}, rhs lifetime: {rhs_spec.lifetime} "
126+
)
127+
message += f"lhs: mem_id {lhs_spec.mem_id} storage: {lhs_spec.mem_offset}, {lhs_spec.allocated_memory} "
128+
message += f"rhs: mem_id {rhs_spec.mem_id} storage: {rhs_spec.mem_offset}, {rhs_spec.allocated_memory}"
129+
return message
130+
120131
def verify_storage_reuse(
121132
self, allow_lifetime_and_storage_overlap: bool = False
122133
) -> int:
@@ -159,7 +170,7 @@ def verify_storage_reuse(
159170
lhs_spec, rhs_spec
160171
):
161172
raise InternalError(
162-
f"Unexpected storage overlap: lhs {lhs_spec}, rhs {rhs_spec}"
173+
f"Unexpected storage overlap: {Verifier._debug_message_from_specs(lhs_spec, rhs_spec)}"
163174
)
164175

165176
# Check that each mem_obj_id is consistent with whether the tensors have
@@ -454,6 +465,18 @@ def update_all_tensors_lifetime(
454465
return specs
455466

456467

468+
@dataclass
469+
class AllocationSpec:
470+
"""
471+
AllocationSpec is used to represent the allocation of a tensor.
472+
"""
473+
474+
# The offset of the tensor in the shared object/pool.
475+
offset: int
476+
# TensorSpec
477+
spec: TensorSpec
478+
479+
457480
@dataclass
458481
class SharedObject:
459482
r"""
@@ -470,8 +493,15 @@ class SharedObject:
470493
offset: int
471494
# size of this shared object in bytes
472495
size: int
496+
# When the object is first created
497+
first_used_index: int
473498
# the object will be available for index (last_used_index + 1)
474499
last_used_index: int
500+
# list of allocations belonging to this shared object
501+
allocations: List[AllocationSpec] = field(default_factory=list)
502+
503+
def __repr__(self) -> str:
504+
return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"
475505

476506

477507
def materialize_buffer(
@@ -489,35 +519,124 @@ def materialize_buffer(
489519
return total_size
490520

491521

492-
def _size_abs_dif(sobj: SharedObject, spec: TensorSpec) -> int:
522+
def _does_not_overlap(sobj: SharedObject, spec: TensorSpec) -> bool:
493523
r"""
494-
Calculate the absolute different between the size of a shared object and
495-
a tensor.
524+
Check if a shared object and a tensor do not overlap.
496525
"""
497-
return abs(sobj.size - spec.allocated_memory)
526+
for alloc in sobj.allocations:
527+
if not (
528+
spec.lifetime[1] < alloc.spec.lifetime[0]
529+
or spec.lifetime[0] > alloc.spec.lifetime[1]
530+
):
531+
return False
532+
return True
533+
534+
535+
def _find_max_overlapping_allocations_offset(
536+
sobj: SharedObject, spec: TensorSpec
537+
) -> int:
538+
max_offset = 0
539+
for alloc in sobj.allocations:
540+
if (
541+
spec.lifetime[1] < alloc.spec.lifetime[0]
542+
or spec.lifetime[0] > alloc.spec.lifetime[1]
543+
):
544+
continue
545+
max_offset = max(alloc.offset + alloc.spec.allocated_memory, max_offset)
546+
return max_offset
498547

499548

500549
def pick_shared_obj(
501-
shared_objects: List[SharedObject], spec: TensorSpec
550+
shared_objects: List[SharedObject],
551+
spec: TensorSpec,
552+
allow_overlapping_allocations: bool = True,
502553
) -> SharedObject:
503554
r"""
504-
Pick the available shared object with closest size to the tensor.
505-
If there are no available shared object left, create a new one.
555+
Pick the available shared object to which to assign this spec,
556+
or create a new one
557+
Algorithm details
558+
Previous: Look at every spec in chronological order. Find if previously allocated object
559+
allows it to fit in. If not, allocate a new object.
560+
New:
561+
- Sort all the specs by allocation size
562+
- Process the specs in order
563+
- If the spec's size is smaller than previously allocated buckets:
564+
- Conditions under which previously allocated bucket can be used:
565+
- Lifetime of the spec does not overlap with lifetime of the bucket.
566+
- In this case allocate spec to that bucket and expand its lifetime.
567+
- Spec is allocated at offset = 0 in this bucket.
568+
- Add this spec to allocated object's list of specs.
569+
- Lifetime of the spec overlaps with lifetime of the bucket,
570+
partially or fully (e.g. spec's lifetime subset of bucket's lifetime)
571+
- If none of the specs in the bucket overlaps with spec's lifetime.
572+
- Allocate spec to the bucket at offset = 0.
573+
- Add this spec to the bucket's list of specs.
574+
- Expand bucket's lifetime accounting for added spec's lifetime.
575+
- If one or more specs in the bucket overlaps with spec's lifetime.
576+
- Collect the offsets (at which each given overlapping spec is allocated in the bucket)
577+
of all the overlapping specs, and find the max offset.
578+
- Allocate spec to the bucket at offset = max_offset + max_offset_spec_size.
579+
- Add this spec to the bucket's list of specs.
580+
- Expand bucket's lifetime accounting for added spec's lifetime.
581+
- If none of these conditions are met, allocate a new bucket.
582+
- Add spec to this bucket.
583+
- Update bucket's lifetime to that of the spec.
584+
- If the spec's size is larger than previously allocated buckets, allocate a new bucket.
585+
- Size and lifetime of this bucket is that of the spec
586+
587+
Proof of correctness:
588+
- If allocating a new bucket, it is correct.
589+
- If allocating spec to an existing bucket, whose lifetime does not overlap with any
590+
of the previously allocated specs' lifetime, then the allocation is correct.
591+
Proof of correctness by induction when adding spec to an existing bucket:
592+
- If all previous allocations in the given bucket are correct:
593+
- Then the new one being added must be correct because when the requested allocation
594+
overlaps with one or more previous allocations, we find the largest offset among
595+
all the overlapping allocations, and allocate the new spec at that offset. Hence,
596+
the allocation at such an offset, will not overlap with any previous allocations.
597+
Base case: A newly added allocation within a bucket with single allocation is correct:
598+
because a) it must fit and b) its lifetime must not overlap with object's lifetime.
599+
This holds true because of the following invariants:
600+
- Once a bucket is created, it is never resized.
601+
- All the allocations within a bucket follow this:
602+
- Span, defined by allocation's offset + size, of two allocations can only overlap,
603+
if their timelines do not overlap.
506604
"""
507-
# TODO: do better than linear scan
508605
picked = None
509606
for sobj in shared_objects:
510-
if spec.lifetime[0] > sobj.last_used_index:
511-
if picked is None or _size_abs_dif(sobj, spec) < _size_abs_dif(
512-
picked, spec
513-
):
514-
picked = sobj
515-
sobj.last_used_index = spec.lifetime[1]
516-
sobj.size = max(sobj.size, spec.allocated_memory)
607+
if _does_not_overlap(sobj, spec):
608+
assert sobj.size >= spec.allocated_memory, "Allocation specs are not sorted"
609+
picked = sobj
610+
sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
611+
sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
612+
allocation_spec = AllocationSpec(0, spec)
613+
picked.allocations.append(allocation_spec)
614+
break
615+
616+
if picked is None and allow_overlapping_allocations:
617+
for sobj in shared_objects:
618+
max_offset = _find_max_overlapping_allocations_offset(sobj, spec)
619+
if max_offset > 0:
620+
if max_offset + spec.allocated_memory <= sobj.size:
621+
picked = sobj
622+
sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
623+
sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
624+
allocation_spec = AllocationSpec(max_offset, spec)
625+
picked.allocations.append(allocation_spec)
626+
break
627+
517628
if picked is None:
518629
picked = SharedObject(
519-
len(shared_objects), -1, spec.allocated_memory, spec.lifetime[1]
630+
len(shared_objects),
631+
-1,
632+
spec.allocated_memory,
633+
spec.lifetime[0],
634+
spec.lifetime[1],
520635
)
636+
allocation_spec = AllocationSpec(0, spec)
637+
picked.allocations.append(allocation_spec)
638+
picked.first_used_index = spec.lifetime[0]
639+
picked.last_used_index = spec.lifetime[1]
521640
shared_objects.append(picked)
522641

523642
return picked
@@ -550,13 +669,50 @@ def get_node_tensor_specs(
550669
]
551670

552671

672+
# Little bit hacky to check if the graph contains
673+
# an XNNPACK delegate.
674+
# Why? XNNPACK kernels may read beyond the end of a tensor for performance,
675+
# so greedy adds extra padding bytes when such a delegate is present.
676+
677+
def _contains_xnnpack_delegate(graph_module: torch.fx.GraphModule) -> bool:
678+
for node in graph_module.graph.nodes:
679+
if node.target == executorch_call_delegate:
680+
lowered_module = getattr(
681+
graph_module.graph.owning_module, node.args[0].target
682+
)
683+
if "xnnpack" in lowered_module.backend_id.lower():
684+
return True
685+
return False
686+
687+
553688
def greedy(
554689
graph_module: torch.fx.GraphModule,
555690
alignment: int,
556691
graph_signature: Optional[ExportGraphSignature] = None,
557692
alloc_graph_input: bool = True,
558693
alloc_graph_output: bool = True,
694+
allow_overlapping_allocations: bool = True,
559695
) -> List[int]:
696+
r"""Greedy algorithm to allocate memory for tensors in the graph.
697+
alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
698+
alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
699+
allow_overlapping_allocations: If set to true, allows for allocations that overlap
700+
in their lifetime but are at different offsets in the storage. By default true.
701+
This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
702+
allocations disabled
703+
"""
704+
# padding allocation with 64 bytes.
705+
# this requirement is really for XNNPACK backend which can read tensors
706+
# beyond the end of the tensor. This is done for performance
707+
# optimizations in XNNPACK.
708+
# While accounting for backend specific requirement is not the right choice
709+
# in backend agnostic memory planning, we do it here as it seems most appropriate.
710+
# Right now this applies to greedy only so any other
711+
# algorithm that plans memory for XNNPACK backend will
712+
# not have this.
713+
extra_padded_bytes = 0
714+
if _contains_xnnpack_delegate(graph_module):
715+
extra_padded_bytes = 64
560716
spec2obj = {}
561717
shared_objects = defaultdict(list)
562718
# Don't do assertion in collect_specs_from_nodes if we have already encountered
@@ -565,24 +721,34 @@ def greedy(
565721
# For each tensor, pick the available shared object with closest size to
566722
# the tensor. If there are no available shared object left, create a new
567723
# one.
724+
import bisect
725+
726+
sorted_specs = []
568727
for spec in collect_specs_from_nodes(
569728
graph_module.graph.nodes,
570729
graph_signature,
571730
do_assertion=do_assertion,
572731
ignore_graph_input=not alloc_graph_input,
573732
ignore_graph_output=not alloc_graph_output,
574733
):
734+
bisect.insort(sorted_specs, spec, key=lambda x: x.allocated_memory)
735+
sorted_specs.reverse()
736+
737+
for spec in sorted_specs:
575738
if spec.mem_id is None:
576739
spec.mem_id = 1
577740
spec.realign(alignment)
578-
spec2obj[spec] = pick_shared_obj(shared_objects[spec.mem_id], spec)
741+
spec2obj[spec] = pick_shared_obj(
742+
shared_objects[spec.mem_id], spec, allow_overlapping_allocations
743+
)
579744

580745
if len(shared_objects) == 0:
581746
# Cannot find any tensor in the graph that needs to be allocated.
582747
# Return [0, 0] to be consistent with default behavior of naive.
583748
total_sizes = [0, 0]
584749
else:
585750
total_sizes = [0] * (max(shared_objects.keys()) + 1)
751+
num_specs_processed = 0
586752
for mem_id in shared_objects:
587753
input_total_size = 0
588754
if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
@@ -594,13 +760,20 @@ def greedy(
594760
total_sizes[mem_id] = materialize_buffer(
595761
shared_objects[mem_id], input_total_size
596762
)
597-
598-
# Since we now know the number of shared objects we need and the size of
599-
# each shared object, we can assign offset in the memory buffer for each
600-
# shared object.
601-
for spec, sobj in spec2obj.items():
602-
spec.mem_obj_id = sobj.idx
603-
spec.mem_offset = sobj.offset
763+
total_sizes[mem_id] += extra_padded_bytes
764+
765+
# Since we now know the number of shared objects we need and the size of
766+
# each shared object, we can assign offset in the memory buffer for each
767+
# shared object.
768+
for sobj in shared_objects[mem_id]:
769+
for alloc in sobj.allocations:
770+
spec = alloc.spec
771+
alloc.spec.mem_obj_id = sobj.idx
772+
alloc.spec.mem_offset = sobj.offset + alloc.offset
773+
num_specs_processed += 1
774+
assert (
775+
len(spec2obj) == num_specs_processed
776+
), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"
604777

605778
logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
606779
return total_sizes

0 commit comments

Comments
 (0)