# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+# pyre-strict
+
import logging
import math
import unittest
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.exir import memory
from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.memory_planning import collect_specs_from_nodes
from executorch.exir.tests.models import MultiLayerPerceptron


class TestMemPlanningPasses(unittest.TestCase):
-    def test_calculate_peak_memory_pass(self):
+    def test_calculate_peak_memory_pass(self) -> None:
        class PeakMemoryTestModel(torch.nn.Module):
            def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
                super().__init__()
@@ -30,7 +33,7 @@ def forward(self, x: torch.Tensor):
                x = self.linear2(x)
                return x

-        def calculate_aligned_num_bytes(num: int, alignment: int = 16):
+        def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
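+            # Rounds num up to the next multiple of alignment; e.g. with the
+            # default 16-byte alignment, 100 bytes -> ceil(100 / 16) * 16 == 112.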
            return math.ceil(num / alignment) * alignment

        # model 1
@@ -84,7 +87,7 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16):
        ) # Align data on a 16 byte boundary
        self.assertEqual(peak_usage, expected_peak_usage)

-    def test_zero_memory_pass(self):
+    def test_zero_memory_pass(self) -> None:
        class ZeroMem(torch.nn.Module):
            def forward(self, x):
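                # x[:, 2::3, ...] keeps every third element along dim 1,
                # starting at index 2.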
                return x[:, 2::3, ...]
@@ -186,7 +189,7 @@ def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
            f"{spec=} {arg_spec=}",
        )

-    def verify_nop_memory_alloc(self, graph_module):
+    def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
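        # Each *_nop node left in the graph is checked below: a cat/slice/select
        # that became a no-op must reuse the memory planned for its argument(s).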
        for node in graph_module.graph.find_nodes(
            op="call_function", target=torch.ops.aten._cat_nop.out
        ):
@@ -202,7 +205,7 @@ def verify_nop_memory_alloc(self, graph_module):
        ):
            self._verify_select_nop_memory_alloc(node)

-    def test_optimize_cat_on_placeholders(self):
+    def test_optimize_cat_on_placeholders(self) -> None:
        class Cat(torch.nn.Module):
            def forward(self, x, y):
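                # A cat of two graph inputs: when the pass can make it a no-op,
                # aten.cat is replaced by _cat_nop (counted in the assertion below).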
                return torch.ops.aten.cat((x, y))
@@ -226,7 +229,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_outermost(self):
+    def test_optimize_cat_outermost(self) -> None:
        class OptimizeCatFeasible1(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -253,7 +256,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_non_outermost(self):
+    def test_optimize_cat_non_outermost(self) -> None:
        class OptimizeCatFeasible2(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -280,7 +283,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_non_outermost(self):
+    def test_no_optimize_cat_non_outermost(self) -> None:
        class OptimizeCatInfeasible1(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -306,7 +309,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_non_outermost1(self):
+    def test_no_optimize_cat_non_outermost1(self) -> None:
        class OptimizeCatInfeasible2(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -333,7 +336,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_slice(self):
+    def test_optimize_cat_with_slice(self) -> None:
        class OptimizeCatSliceFeasible(torch.nn.Module):
            def forward(self, x):
                x1 = torch.add(x, 2.4, 3.1)
@@ -362,7 +365,7 @@ def forward(self, x):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_slice_infeasible(self):
+    def test_optimize_cat_with_slice_infeasible(self) -> None:
        class OptimizeCatSliceInfeasible(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -388,7 +391,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_slice_Tensor(self):
+    def test_optimize_slice_Tensor(self) -> None:
        class SliceTensor(torch.nn.Module):
            def forward(self, x, y, z):
                x1 = torch.add(x, 2.4, 3.1)
@@ -450,7 +453,7 @@ def forward(self, x, y, z):
        )
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_select_Tensor(self):
+    def test_optimize_select_Tensor(self) -> None:
        class SelectTensor(torch.nn.Module):
            def forward(self, x, y, z):
                x1 = torch.add(x, 2.4, 3.1)
@@ -517,7 +520,7 @@ def forward(self, x, y, z):

    # TODO: Test fails due to memory planning
    @unittest.expectedFailure
-    def test_optimize_cat_with_param(self):
+    def test_optimize_cat_with_param(self) -> None:
        class CatWithPadding(torch.nn.Module):
            def __init__(self, padding_shape):
                super().__init__()
@@ -545,7 +548,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_then_slice_on_mutable_buffer(self):
+    def test_optimize_cat_then_slice_on_mutable_buffer(self) -> None:
        class CatWithPadding(torch.nn.Module):
            def __init__(self, padding_shape):
                super().__init__()
@@ -570,7 +573,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_view(self):
+    def test_optimize_cat_with_view(self) -> None:
        class CatViewFeasible(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -597,7 +600,7 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
        self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_with_repeated_args(self):
+    def test_no_optimize_cat_with_repeated_args(self) -> None:
        class CatViewInfeasible(torch.nn.Module):
            def forward(self, x):
                x1 = torch.add(x, 2.4, 3.1)
@@ -621,7 +624,7 @@ def forward(self, x):
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
        self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_with_placeholder(self):
+    def test_no_optimize_cat_with_placeholder(self) -> None:
        class CatViewInfeasible(torch.nn.Module):
            def forward(self, x, y):
                # Repeat will be decomposed into a cat. The cat cannot be optimized
@@ -739,7 +742,7 @@ def forward(self, x) -> torch.Tensor:
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
        self.verify_nop_memory_alloc(graph_module)

-    def test_view_for_unallocated_output(self):
+    def test_view_for_unallocated_output(self) -> None:
        class Model(torch.nn.Module):
            def __init__(self, padding_shape):
                super().__init__()
@@ -762,3 +765,40 @@ def forward(self, x, y):
        )
        self.assertEqual(count_node(graph_module, memory.view), 1)
        self.verify_nop_memory_alloc(graph_module)
+
+    def test_start_alignment_constraints(self) -> None:
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
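+                # A chain of adds keeps several intermediates alive at once,
+                # giving the memory planner multiple buffers to place and align.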
+                add_0 = torch.add(x, y)
+                add_1 = torch.add(x, add_0)
+                add_2 = torch.add(add_0, add_1)
+                add_3 = torch.add(add_1, add_2)
+                return add_3
+
+        model = Model()
+        inputs = (torch.randn(4, 17), torch.randn(4, 17))
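+        # Run each memory planning algorithm selectable via mem_algo (0 and 1);
+        # the alignment constraint must hold regardless of the planner used.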
+        for mem_algo in range(0, 2):
+            graph_module = (
+                compiler.export_to_executorch_gen_etrecord(
+                    model,
+                    inputs,
+                    opt_level=1,
+                    mem_algo=mem_algo,
+                    alloc_graph_input=False,
+                    alloc_graph_output=False,
+                    mem_alignment=37,
+                )
+                .exported_program()
+                .graph_module
+            )
+            # Assert that every planned allocation starts on the requested
+            # 37-byte (mem_alignment) boundary.
+            for spec in collect_specs_from_nodes(
+                graph_module.graph.nodes,
+                ignore_graph_input=True,
+                ignore_graph_output=True,
+            ):
+                if spec and spec.mem_offset:
+                    self.assertEqual(spec.mem_offset % 37, 0)
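+                    # mem_offset is relative to the start of the planned memory
+                    # arena, so offsets that are multiples of 37 give 37-byte
+                    # start alignment (assuming the arena base itself is aligned).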