@@ -40,6 +40,12 @@ def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
     return memory_config.memory_sizes[exir_id - 1]
 
 
+def get_alignment(memory_config: MemoryConfig, exir_id: int) -> int:
+    # EXIR's spec.mem_id is indexed from 1..N.
+    assert memory_config.memory_alignments is not None
+    return memory_config.memory_alignments[exir_id - 1]
+
+
 def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
     return int(math.ceil(pre_aligned_offset / alignment) * alignment)
 
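A quick sketch of how these two helpers compose; `SimpleConfig` below is a hypothetical stand-in for the module's `MemoryConfig`, used only so the snippet runs on its own:

```python
import math
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SimpleConfig:
    # Hypothetical stand-in for MemoryConfig: one size and one alignment per bank.
    memory_sizes: List[int]
    memory_alignments: Optional[List[int]] = None


def get_alignment(memory_config: SimpleConfig, exir_id: int) -> int:
    # Mirrors the helper added above: mem_id is 1-indexed, lists are 0-indexed.
    assert memory_config.memory_alignments is not None
    return memory_config.memory_alignments[exir_id - 1]


def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
    # Round up to the next multiple of `alignment`.
    return int(math.ceil(pre_aligned_offset / alignment) * alignment)


config = SimpleConfig(memory_sizes=[1024, 4096], memory_alignments=[8, 16])
# Bank 2 (1-indexed) uses memory_alignments[1] == 16, so offset 100 rounds to 112.
print(get_aligned_offset(100, get_alignment(config, 2)))  # 112
```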
@@ -84,6 +90,10 @@ def position_based_greedy_with_hierarchy(
         ]
     ] = None,
 ) -> List[int]:
+    # We do not use the `alignment` parameter and instead use the per-memory alignment
+    # constraints from `memory_config`.
+    del alignment
+
     num_memories = get_num_memories(memory_config)
     bufsizes = [0] * num_memories
     allocated_buffers: List[List[TensorSpec]] = [[] for _ in range(num_memories)]
@@ -103,7 +113,8 @@ def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
 
     def memory_available(spec: TensorSpec) -> bool:
         return get_aligned_offset(
-            spec.mem_offset + spec.allocated_memory, alignment
+            spec.mem_offset + spec.allocated_memory,
+            get_alignment(memory_config, spec.mem_id),
         ) <= get_size(memory_config, spec.mem_id)
 
     # Iterate over all the specs in sorted order
@@ -124,7 +135,8 @@ def memory_available(spec: TensorSpec) -> bool:
         spec.mem_offset = 0
         while memory_available(spec) and (overlapped := overlap(spec)):
             spec.mem_offset = get_aligned_offset(
-                overlapped.mem_offset + overlapped.allocated_memory, alignment
+                overlapped.mem_offset + overlapped.allocated_memory,
+                get_alignment(memory_config, spec.mem_id),
             )
         if memory_available(spec):
             allocated_buffers[spec.mem_id].append(spec)
@@ -172,6 +184,10 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
         ]
     ] = None,
 ) -> List[int]:
+    # We do not use the `alignment` parameter and instead use the per-memory alignment
+    # constraints from `memory_config`.
+    del alignment
+
     num_memories = get_num_memories(memory_config)
     bufsizes = [0] * num_memories
     allocated_buffers = [[] for _ in range(num_memories)]
@@ -213,13 +229,14 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
                 prev_offset = max(
                     get_aligned_offset(
                         allocated_spec.mem_offset + allocated_spec.allocated_memory,
-                        alignment,
+                        get_alignment(memory_config, spec.mem_id),
                     ),
                     prev_offset,
                 )
         if spec.mem_offset is None:
             if get_aligned_offset(
-                prev_offset + spec.allocated_memory, alignment
+                prev_offset + spec.allocated_memory,
+                get_alignment(memory_config, spec.mem_id),
             ) > get_size(memory_config, spec.mem_id):
                 continue
             else:
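A worked check of why the per-bank alignment matters in the size test above; the numbers are illustrative only, not taken from the PR:

```python
import math


def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
    return int(math.ceil(pre_aligned_offset / alignment) * alignment)


# prev_offset + allocated_memory = 990 in a 1000-byte bank. Whether the
# placement survives the `> get_size(...)` test depends on the bank's alignment:
print(get_aligned_offset(990, 8))   # 992  -> fits (992 <= 1000)
print(get_aligned_offset(990, 64))  # 1024 -> rejected (1024 > 1000)
```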
@@ -439,7 +456,6 @@ def __init__(
                 ]
             ]
         ] = None,
-        mem_alignment: int = 1,
     ) -> None:
         self._init_mem_algos()
 
@@ -450,9 +466,6 @@ def __init__(
         self.alloc_graph_output = alloc_graph_output
         self.additional_constraint_gen_passes = additional_constraint_gen_passes
 
-        assert mem_alignment > 0, "mem_alignment must be positive"
-        self.mem_alignment = mem_alignment
-
     def _init_mem_algos(self) -> None:
         self.available_mem_algos = [
             position_based_greedy_with_hierarchy,
@@ -489,7 +502,6 @@ def run(
             allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
-            alignment=self.mem_alignment,
         )
         mem_planning.run(graph_module, graph_signature)
 