# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

+import logging
import math
import unittest
+from typing import cast

import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
@@ -110,7 +112,89 @@ def forward(self, x):
class TestMemTransform(unittest.TestCase):
-    def test_optimize_cat(self):
+    def verify_cat_nop_memory_alloc(self, cat_node: torch.fx.Node) -> None:
+        # A nop cat is only valid when the concat dim is outermost (all
+        # leading dims are 1), so each input aliases a contiguous chunk
+        # of the output buffer.
+        spec = cat_node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, cat_node.args[1]) if len(cat_node.args) > 1 else 0
+        cat_outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            cat_outer_size,
+            1,
+            f"{cat_node=} has wrong outer size: {cat_outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
+        dim_offset = 0
+        for arg in cast(list[torch.fx.Node], cat_node.args[0]):
+            arg_spec = arg.meta.get("spec", None)
+            self.assertEqual(arg_spec.mem_id, spec.mem_id)
+            self.assertEqual(
+                arg_spec.mem_offset,
+                spec.mem_offset + dim_offset * inner_dim_elements,
+                f"{arg=} for node {cat_node=} has wrong memory offset: {arg_spec.mem_offset=} {dim_offset=} for cat on {dim=}, but output has {spec.mem_offset=}",
+            )
+            dim_offset += arg_spec.shape[dim]
+
+    def verify_slice_nop_memory_alloc(self, slice_node: torch.fx.Node) -> None:
+        # A nop slice's output must alias its input buffer, offset by the
+        # elements skipped along the slice dimension.
+        spec = slice_node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, slice_node.args[1]) if len(slice_node.args) > 1 else 0
+        slice_outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            slice_outer_size,
+            1,
+            f"{slice_node=} has wrong outer size: {slice_outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
+        start: int = (
+            cast(int, slice_node.args[2])
+            if (len(slice_node.args) > 2 and slice_node.args[2] is not None)
+            else 0
+        )
+        arg = cast(torch.fx.Node, slice_node.args[0])
+        arg_spec = arg.meta.get("spec", None)
+        self.assertEqual(arg_spec.mem_id, spec.mem_id)
+        self.assertEqual(
+            spec.mem_offset,
+            arg_spec.mem_offset + start * inner_dim_elements,
+            f"{arg=} for node {slice_node=} has wrong memory offset: {arg_spec.mem_offset=} {start=} for slice on {dim=}, but output has {spec.mem_offset=}",
+        )
+
+    def verify_nop_memory_alloc(self, graph_module):
+        for cat_node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._cat_nop.out
+        ):
+            self.verify_cat_nop_memory_alloc(cat_node)
+
+        for slice_node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._slice_copy_nop.Tensor_out
+        ):
+            self.verify_slice_nop_memory_alloc(slice_node)
+
+    def test_optimize_cat_on_placeholders(self):
+        class Cat(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.cat((x, y))
+
+        x = torch.ones(3, 6)
+        y = torch.ones(2, 6)
+        # Optimizing cat ops only happens at opt_level 2+, and requires the
+        # memory planning pass to run:
+        graph_module = (
+            compiler.export_to_executorch_gen_etrecord(
+                Cat(), (x, y), opt_level=2, mem_algo=1
+            )
+            .exported_program()
+            .graph_module
+        )
+        logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
+        graph_module.graph.eliminate_dead_code()
+        # Assert that cat op is optimized away
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        # Assert that cat op is replaced by its nop version post optimization
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_cat_outermost(self):
        class OptimizeCatFeasible1(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
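
As a sanity check on the arithmetic `verify_cat_nop_memory_alloc` asserts, the expected input offsets can be worked out by hand. A minimal standalone sketch (not part of the diff), using the float32 `(3, 6)` and `(2, 6)` inputs from `test_optimize_cat_on_placeholders`, concatenated on dim 0:

    # Offsets a nop cat implies for inputs of shape (3, 6) and (2, 6).
    import math
    import torch

    out_shape = (5, 6)  # result of cat on dim 0
    dim = 0
    # Bytes spanned by one step along `dim`: product of inner dims * itemsize.
    inner_dim_elements = math.prod(out_shape[dim + 1 :]) * torch.float32.itemsize
    assert inner_dim_elements == 24  # 6 elements * 4 bytes
    # The first input sits at the output's own offset; the second starts
    # after the first input's 3 rows.
    assert [0 * inner_dim_elements, 3 * inner_dim_elements] == [0, 72]
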
@@ -135,7 +219,9 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
        # Assert that cat op is replaced by its nop version post optimization
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_optimize_cat_non_outermost(self):
        class OptimizeCatFeasible2(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -160,7 +246,9 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
        # Assert that cat op is replaced by its nop version post optimization
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_no_optimize_cat_non_outermost(self):
        class OptimizeCatInfeasible1(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -184,7 +272,9 @@ def forward(self, x, y):
        # Assert that cat op is not optimized away, since the concat is not
        # along the outermost dim
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_no_optimize_cat_non_outermost1(self):
        class OptimizeCatInfeasible2(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -209,6 +299,7 @@ def forward(self, x, y):
        # offsets are not multiple of 8 bytes, and the cat is not the output
        # of the graph.
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

    def test_optimize_cat_with_slice(self):
        class OptimizeCatSliceFeasible(torch.nn.Module):
@@ -237,6 +328,7 @@ def forward(self, x):
        graph_module.graph.eliminate_dead_code()
        # Assert that cat op is optimized away
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

    def test_optimize_cat_with_slice_infeasible(self):
        class OptimizeCatSliceInfeasible(torch.nn.Module):
@@ -262,6 +354,7 @@ def forward(self, x, y):
        graph_module.graph.eliminate_dead_code()
        # Assert that cat op is not optimized away
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

    def test_optimize_slice_Tensor(self):
        class SliceTensor(torch.nn.Module):
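
The slice counterpart follows the same layout rule: `verify_slice_nop_memory_alloc` expects the output spec to point `start * inner_dim_elements` bytes into the input. A minimal standalone sketch, with a hypothetical `(5, 6)` float32 input sliced from row 3 (shapes chosen for illustration only):

    # Offset a nop slice implies for slice(input, dim=0, start=3).
    import math
    import torch

    in_shape = (5, 6)
    dim, start = 0, 3
    inner_dim_elements = math.prod(in_shape[dim + 1 :]) * torch.float32.itemsize
    # The output aliases the input buffer, offset by the skipped rows.
    assert start * inner_dim_elements == 72  # 3 rows * 24 bytes each
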
@@ -323,6 +416,7 @@ def forward(self, x, y, z):
        self.assertEqual(
            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 3
        )
+        self.verify_nop_memory_alloc(graph_module)

    def test_optimize_select_Tensor(self):
        class SelectTensor(torch.nn.Module):
@@ -387,6 +481,7 @@ def forward(self, x, y, z):
        self.assertEqual(
            count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 3
        )
+        self.verify_nop_memory_alloc(graph_module)

    # TODO: Test fails due to memory planning
    @unittest.expectedFailure
@@ -416,6 +511,32 @@ def forward(self, x, y):
        graph_module.graph.eliminate_dead_code()
        # Assert that cat op is not optimized away
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_cat_then_slice_on_mutable_buffer(self):
+        class CatWithPadding(torch.nn.Module):
+            def __init__(self, padding_shape):
+                super().__init__()
+                zeros = torch.zeros(padding_shape)
+                self.register_buffer("padding", zeros)
+
+            def forward(self, x, y):
+                x = x.view(3, 5)
+                cat = torch.ops.aten.cat((x, self.padding.clone()))
+                slice_copy = torch.ops.aten.slice(cat, dim=0, start=x.shape[0])
+                self.padding.copy_(slice_copy)
+                return cat.view(-1) + y
+
+        x = torch.ones(15)
+        y = torch.ones(1)
+        et_prog_manager = compiler.export_to_executorch_gen_etrecord(
+            CatWithPadding((1, 5)), (x, y), opt_level=3
+        )
+        graph_module = et_prog_manager.exported_program().graph_module
+        logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

    def test_optimize_cat_with_view(self):
        class CatViewFeasible(torch.nn.Module):
@@ -442,6 +563,7 @@ def forward(self, x, y):
        # Assert that cat op is optimized away
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

    def test_no_optimize_cat_with_repeated_args(self):
        class CatViewInfeasible(torch.nn.Module):
@@ -465,6 +587,7 @@ def forward(self, x):
        # Assert that cat op is not optimized away
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

    def test_no_optimize_cat_with_placeholder(self):
        class CatViewInfeasible(torch.nn.Module):
@@ -492,6 +615,7 @@ def forward(self, x, y):
        # Assert that cat op is not optimized away
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

    def test_no_optimize_cat(self) -> None:
        class Model(torch.nn.Module):
@@ -522,6 +646,7 @@ def forward(self, x) -> torch.Tensor:
            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 2
        )
        self.assertEqual(count_node(graph_module, memory.view), 2)
+        self.verify_nop_memory_alloc(graph_module)

    def test_optimize_slice_copy(self) -> None:
        class Model(torch.nn.Module):
@@ -553,6 +678,7 @@ def forward(self, x) -> torch.Tensor:
            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 0
        )
        self.assertEqual(count_node(graph_module, memory.view), 2)
+        self.verify_nop_memory_alloc(graph_module)

    def test_cat_then_cat(self) -> None:
        class Model(torch.nn.Module):
@@ -579,6 +705,7 @@ def forward(self, x) -> torch.Tensor:
        graph_module.print_readable()
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 2)
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

    def test_view_for_unallocated_output(self):
        class Model(torch.nn.Module):
@@ -602,3 +729,4 @@ def forward(self, x, y):
            .graph_module
        )
        self.assertEqual(count_node(graph_module, memory.view), 1)
+        self.verify_nop_memory_alloc(graph_module)