Commit 3975a75
Update on "[ET][Memory planning] Improve greedy memory planning."
This diff replaces the old greedy algorithm, which produced memory plans about 35% worse than the theoretical optimum. This matters even more for long contexts, where the additional overhead can be a few hundred MB. For example, the theoretical optimum for llama3_2 8B, 4-bit quantized, with a context length of 2k needs about 1 GB of memory. This theoretical optimum can be observed by looking at the peaks in the memory profile. The current algorithm resulted in about 1.6 GB of planned memory; the new algorithm reduces that to about 1.1 GB.

Differential Revision: [D68448332](https://our.internmc.facebook.com/intern/diff/D68448332/)

[ghstack-poisoned]
1 parent 816efe9 commit 3975a75
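
As background on what a greedy planner does here: tensors whose lifetimes do not overlap can reuse the same shared storage object, so the planned arena size is the sum of object sizes rather than the sum of tensor sizes. Below is a toy sketch of that idea, assuming tensors described by hypothetical (start, end, size) lifetime tuples; it is illustrative only, not the ExecuTorch implementation.

```python
from typing import List, Tuple

def greedy_plan(tensors: List[Tuple[int, int, int]]) -> int:
    """Toy greedy planner: each tensor is (start, end, size).

    A tensor reuses the first storage slot whose occupant's lifetime has
    ended; otherwise a new slot is opened. Returns total planned bytes.
    """
    slots: List[Tuple[int, int]] = []  # (occupant_end, slot_size) per slot
    for start, end, size in sorted(tensors):
        for i, (occupant_end, slot_size) in enumerate(slots):
            if occupant_end <= start:  # previous occupant is dead: reuse slot
                slots[i] = (end, max(slot_size, size))
                break
        else:
            slots.append((end, size))  # no free slot: open a new one
    return sum(slot_size for _, slot_size in slots)

# Three tensors; the third can reuse the first's slot, so 150 bytes, not 230.
print(greedy_plan([(0, 2, 100), (1, 3, 50), (2, 4, 80)]))  # -> 150
```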

File tree

2 files changed: +23 additions, −4 deletions

backends/vulkan/vulkan_preprocess.py
exir/memory_planning.py

backends/vulkan/vulkan_preprocess.py

Lines changed: 7 additions & 1 deletion
@@ -6,6 +6,8 @@
 
 # pyre-strict
 
+from functools import partial
+
 from typing import Any, Dict, final, List
 
 import executorch.backends.vulkan.utils as utils
@@ -18,6 +20,9 @@
 from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
 
+from executorch.exir.memory_planning import (
+    greedy,
+)
 from executorch.backends.vulkan._passes import (
     insert_prepack_nodes,
     RemoveLocalScalarDenseOpsTransform,
@@ -189,11 +194,12 @@ def preprocess(  # noqa: C901
 
     # Finally, apply dynamic shape passes and memory planning pass. These passes
     # must be applied only when the graph structure is finalized.
+    greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
     program = apply_passes(
         program,
         [
             ConstraintBasedSymShapeEvalPass(),
-            MemoryPlanningPass(),
+            MemoryPlanningPass(memory_planning_algo=greedy_memory_planning),
         ],
     )
 
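The functools.partial in the hunk above pre-binds the new keyword argument so MemoryPlanningPass can invoke the result exactly as it would invoke the plain algorithm. A minimal sketch of the mechanism, using a simplified stand-in for the real greedy() signature:

```python
from functools import partial

# Simplified stand-in; the real greedy() takes graph/alignment arguments
# plus the new allow_overlapping_allocations keyword shown in this commit.
def greedy(graph_module, alignment, allow_overlapping_allocations=True):
    return f"planned (overlap allowed: {allow_overlapping_allocations})"

# Pre-bind the flag; the partial is called with greedy's remaining arguments.
greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)

print(greedy_memory_planning("graph", 16))  # planned (overlap allowed: False)
```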
exir/memory_planning.py

Lines changed: 16 additions & 3 deletions
@@ -547,7 +547,9 @@ def _find_max_overlapping_allocations_offset(
 
 
 def pick_shared_obj(
-    shared_objects: List[SharedObject], spec: TensorSpec
+    shared_objects: List[SharedObject],
+    spec: TensorSpec,
+    allow_overlapping_allocations: bool = True,
 ) -> SharedObject:
     r"""
     Pick the available shared object to which to assign this spec,
@@ -611,7 +613,7 @@ def pick_shared_obj(
             picked.allocations.append(allocation_spec)
             break
 
-    if picked is None:
+    if picked is None and allow_overlapping_allocations:
         for sobj in shared_objects:
             max_offset = _find_max_overlapping_allocations_offset(sobj, spec)
             if max_offset > 0:
@@ -673,7 +675,16 @@ def greedy(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
 ) -> List[int]:
+    r"""Greedy algorithm to allocate memory for tensors in the graph.
+    alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
+    alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
+    allow_overlapping_allocations: If set to true, allows for allocations that overlap
+    in their lifetime but are at different offsets in the storage. By default true.
+    This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
+    allocations disabled
+    """
     spec2obj = {}
     shared_objects = defaultdict(list)
     # Don't do assertion in collect_specs_from_nodes if we have already encountered
@@ -699,7 +710,9 @@ def greedy(
         if spec.mem_id is None:
             spec.mem_id = 1
         spec.realign(alignment)
-        spec2obj[spec] = pick_shared_obj(shared_objects[spec.mem_id], spec)
+        spec2obj[spec] = pick_shared_obj(
+            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
+        )
 
     if len(shared_objects) == 0:
         # Cannot find any tensor in the graph that needs to be allocated.
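
To make the flag's effect concrete: when allow_overlapping_allocations is true, a spec whose lifetime conflicts with every existing shared object can still be packed into one of them at a non-zero offset (the _find_max_overlapping_allocations_offset fallback above); when false, that fallback is skipped and the spec gets a fresh object. Below is a toy model of the two behaviors, with invented SharedObj and assign helpers; it is not the ExecuTorch code.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class SharedObj:
    size: int
    allocs: List[Tuple[int, int]] = field(default_factory=list)  # (offset, nbytes)

def assign(objs: List[SharedObj], nbytes: int, allow_overlapping: bool) -> SharedObj:
    # Assume every existing object already holds a live (overlapping) tensor.
    if allow_overlapping and objs:
        obj = objs[0]
        obj.allocs.append((obj.size, nbytes))  # reuse the object at a new offset
        obj.size += nbytes
        return obj
    obj = SharedObj(nbytes, [(0, nbytes)])     # dedicated object, offset 0
    objs.append(obj)
    return obj

objs: List[SharedObj] = []
assign(objs, 100, allow_overlapping=False)
assign(objs, 50, allow_overlapping=False)
print(len(objs))  # 2: each live tensor gets its own backing object (Vulkan mode)
```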
