@@ -40,6 +40,12 @@ def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
40
40
return memory_config .memory_sizes [exir_id - 1 ]
41
41
42
42
43
+ def get_alignment (memory_config : MemoryConfig , exir_id : int ) -> int :
44
+ # EXIR's spec.mem_id is indexed from 1..N.
45
+ assert memory_config .memory_alignments is not None
46
+ return memory_config .memory_alignments [exir_id - 1 ]
47
+
48
+
43
49
def get_aligned_offset (pre_aligned_offset : int , alignment : int ) -> int :
44
50
return int (math .ceil (pre_aligned_offset / alignment ) * alignment )
45
51
@@ -84,6 +90,10 @@ def position_based_greedy_with_hierarchy(
84
90
]
85
91
] = None ,
86
92
) -> List [int ]:
93
+ # We do not use the `alignment` parameter and instead use the per-memory alignment
94
+ # constraints from `memory_config`.
95
+ del alignment
96
+
87
97
num_memories = get_num_memories (memory_config )
88
98
bufsizes = [0 ] * num_memories
89
99
allocated_buffers : List [List [TensorSpec ]] = [[] for _ in range (num_memories )]
@@ -103,7 +113,7 @@ def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
103
113
104
114
def memory_available (spec : TensorSpec ) -> bool :
105
115
return get_aligned_offset (
106
- spec .mem_offset + spec .allocated_memory , alignment
116
+ spec .mem_offset + spec .allocated_memory , get_alignment ( memory_config , spec . mem_id )
107
117
) <= get_size (memory_config , spec .mem_id )
108
118
109
119
# Iterate over all the specs in sorted order
@@ -124,7 +134,7 @@ def memory_available(spec: TensorSpec) -> bool:
124
134
spec .mem_offset = 0
125
135
while memory_available (spec ) and (overlapped := overlap (spec )):
126
136
spec .mem_offset = get_aligned_offset (
127
- overlapped .mem_offset + overlapped .allocated_memory , alignment
137
+ overlapped .mem_offset + overlapped .allocated_memory , get_alignment ( memory_config , spec . mem_id )
128
138
)
129
139
if memory_available (spec ):
130
140
allocated_buffers [spec .mem_id ].append (spec )
@@ -172,6 +182,10 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
172
182
]
173
183
] = None ,
174
184
) -> List [int ]:
185
+ # We do not use the `alignment` parameter and instead use the per-memory alignment
186
+ # constraints from `memory_config`.
187
+ del alignment
188
+
175
189
num_memories = get_num_memories (memory_config )
176
190
bufsizes = [0 ] * num_memories
177
191
allocated_buffers = [[] for _ in range (num_memories )]
@@ -213,13 +227,13 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
213
227
prev_offset = max (
214
228
get_aligned_offset (
215
229
allocated_spec .mem_offset + allocated_spec .allocated_memory ,
216
- alignment ,
230
+ get_alignment ( memory_config , spec . mem_id ) ,
217
231
),
218
232
prev_offset ,
219
233
)
220
234
if spec .mem_offset is None :
221
235
if get_aligned_offset (
222
- prev_offset + spec .allocated_memory , alignment
236
+ prev_offset + spec .allocated_memory , get_alignment ( memory_config , spec . mem_id )
223
237
) > get_size (memory_config , spec .mem_id ):
224
238
continue
225
239
else :
@@ -439,7 +453,6 @@ def __init__(
439
453
]
440
454
]
441
455
] = None ,
442
- mem_alignment : int = 1 ,
443
456
) -> None :
444
457
self ._init_mem_algos ()
445
458
@@ -450,9 +463,6 @@ def __init__(
450
463
self .alloc_graph_output = alloc_graph_output
451
464
self .additional_constraint_gen_passes = additional_constraint_gen_passes
452
465
453
- assert mem_alignment > 0 , "mem_alignment must be positive"
454
- self .mem_alignment = mem_alignment
455
-
456
466
def _init_mem_algos (self ) -> None :
457
467
self .available_mem_algos = [
458
468
position_based_greedy_with_hierarchy ,
@@ -489,7 +499,6 @@ def run(
489
499
allow_lifetime_and_storage_overlap = (self .opt_level >= 2 ),
490
500
alloc_graph_input = self .alloc_graph_input ,
491
501
alloc_graph_output = self .alloc_graph_output ,
492
- alignment = self .mem_alignment ,
493
502
)
494
503
mem_planning .run (graph_module , graph_signature )
495
504
0 commit comments