```diff
@@ -9,6 +9,7 @@
 import collections
 import itertools
 import logging
+import math
 import typing
 from functools import partial
 from typing import Iterable, List, Optional, Tuple
@@ -39,6 +40,12 @@ def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
     return memory_config.memory_sizes[exir_id - 1]


+def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
+    if alignment == 0:
+        return pre_aligned_offset
+    return int(math.ceil(pre_aligned_offset / alignment) * alignment)
+
+
 def collect_specs_from_graph_module(
     graph_module: torch.fx.GraphModule,
     alloc_graph_input: bool,
```
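
The new `get_aligned_offset` helper rounds an offset up to the next multiple of `alignment`, with `0` meaning "no alignment". A minimal standalone check of that behavior, with the helper copied from the hunk above so the snippet runs on its own:

```python
import math


def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
    # Copied from the hunk above: alignment == 0 disables rounding.
    if alignment == 0:
        return pre_aligned_offset
    return int(math.ceil(pre_aligned_offset / alignment) * alignment)


assert get_aligned_offset(13, 8) == 16  # rounded up to the next multiple of 8
assert get_aligned_offset(16, 8) == 16  # already-aligned offsets pass through
assert get_aligned_offset(13, 0) == 13  # alignment == 0 is the identity
```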
```diff
@@ -95,7 +102,7 @@ def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
         return None

     def memory_available(spec: TensorSpec) -> bool:
-        return spec.mem_offset + spec.allocated_memory <= get_size(
+        return get_aligned_offset(spec.mem_offset + spec.allocated_memory, alignment) <= get_size(
             memory_config, spec.mem_id
         )

@@ -116,7 +123,7 @@ def memory_available(spec: TensorSpec) -> bool:
            continue
        spec.mem_offset = 0
        while memory_available(spec) and (overlapped := overlap(spec)):
-            spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
+            spec.mem_offset = get_aligned_offset(overlapped.mem_offset + overlapped.allocated_memory, alignment)
        if memory_available(spec):
            allocated_buffers[spec.mem_id].append(spec)
            bufsizes[spec.mem_id] = max(
```
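
With a nonzero alignment, the placement loop above no longer packs buffers back-to-back: each time the candidate offset is bumped past an overlapping buffer it is rounded up again, so padding accumulates between allocations. A small illustration, assuming three overlapping 20-byte tensors and 16-byte alignment, reusing the `get_aligned_offset` sketch above:

```python
offset = 0
for size in (20, 20, 20):  # three tensors whose lifetimes all overlap
    print(offset)  # prints 0, 32, 64 rather than the unaligned 0, 20, 40
    offset = get_aligned_offset(offset + size, 16)
```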
```diff
@@ -202,11 +209,11 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
            # calculation of gap incorrect. Moving it out will make the algorithm degenerate
            # to the naive one, reusing 0 tensor. The paper may have a typo here.
            prev_offset = max(
-                allocated_spec.mem_offset + allocated_spec.allocated_memory,
+                get_aligned_offset(allocated_spec.mem_offset + allocated_spec.allocated_memory, alignment),
                prev_offset,
            )
        if spec.mem_offset is None:
-            if prev_offset + spec.allocated_memory > get_size(
+            if get_aligned_offset(prev_offset + spec.allocated_memory, alignment) > get_size(
                memory_config, spec.mem_id
            ):
                continue
@@ -423,6 +430,7 @@ def __init__(
            ]
        ]
    ] = None,
+    mem_alignment: int = 0,
 ) -> None:
     self._init_mem_algos()

@@ -432,6 +440,7 @@ def __init__(
        self.alloc_graph_input = alloc_graph_input
        self.alloc_graph_output = alloc_graph_output
        self.additional_constraint_gen_passes = additional_constraint_gen_passes
+        self.mem_alignment = mem_alignment

    def _init_mem_algos(self) -> None:
        self.available_mem_algos = [
@@ -459,6 +468,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
            allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
            alloc_graph_input=self.alloc_graph_input,
            alloc_graph_output=self.alloc_graph_output,
+            alignment=self.mem_alignment,
        )
        mem_planning(graph_module)
```
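
End to end, the new knob defaults to `0`, in which case `get_aligned_offset` is a pass-through and the planner behaves exactly as before. A hypothetical usage sketch; the class name `MemoryPlanningPass` and the exact argument list are inferred from the attributes visible in the diff, not confirmed signatures:

```python
# Hypothetical: class name and constructor arguments are inferred from
# the diff (opt_level, alloc_graph_input, ...), not a confirmed API.
planning_pass = MemoryPlanningPass(
    opt_level=2,
    alloc_graph_input=True,
    alloc_graph_output=True,
    mem_alignment=16,  # round every buffer boundary up to 16 bytes
)
result = planning_pass(graph_module)  # plans memory for the FX graph
```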