 # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

+import logging
 import math
 import unittest
+from typing import cast

 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
@@ -110,7 +112,119 @@ def forward(self, x):


 class TestMemTransform(unittest.TestCase):
-    def test_optimize_cat(self):
+    def verify_cat_nop_memory_alloc(self, cat_node: torch.fx.Node) -> None:
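+        # A nop cat must lay its inputs out back-to-back inside the output
+        # buffer: input k has to start at the output's mem_offset plus the
+        # combined byte size of all preceding inputs along the cat dim.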
+        spec = cat_node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, cat_node.args[1]) if len(cat_node.args) > 1 else 0
+        cat_outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            cat_outer_size,
+            1,
+            f"{cat_node=} has wrong outer size: {cat_outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
+        dim_offset = 0
+        for arg in cast(list[torch.fx.Node], cat_node.args[0]):
+            arg_spec = arg.meta.get("spec", None)
+            self.assertEqual(arg_spec.mem_id, spec.mem_id)
+            self.assertEqual(
+                arg_spec.mem_offset,
+                spec.mem_offset + dim_offset * inner_dim_elements,
+                f"{arg=} for node {cat_node=} has wrong memory offset: {arg_spec.mem_offset=} {dim_offset=} for cat on {dim=}, but output has {spec.mem_offset=}",
+            )
+            dim_offset += arg_spec.shape[dim]
+
+    def verify_slice_nop_memory_alloc(self, slice_node: torch.fx.Node) -> None:
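+        # A nop slice must alias its input: the output's mem_offset equals the
+        # input's mem_offset advanced by `start` steps along the sliced dim,
+        # measured in bytes of the inner dims.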
+        spec = slice_node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, slice_node.args[1]) if len(slice_node.args) > 1 else 0
+        slice_outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            slice_outer_size,
+            1,
+            f"{slice_node=} has wrong outer size: {slice_outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
+        start: int = (
+            cast(int, slice_node.args[2])
+            if (len(slice_node.args) > 2 and slice_node.args[2] is not None)
+            else 0
+        )
+        arg = cast(torch.fx.Node, slice_node.args[0])
+        arg_spec = arg.meta.get("spec", None)
+        self.assertEqual(arg_spec.mem_id, spec.mem_id)
+        self.assertEqual(
+            spec.mem_offset,
+            arg_spec.mem_offset + start * inner_dim_elements,
+            f"{arg=} for node {slice_node=} has wrong memory offset: {arg_spec.mem_offset=} {start=} for slice on {dim=}, but output has {spec.mem_offset=}",
+        )
+
+    def verify_select_nop_memory_alloc(self, select_node: torch.fx.Node) -> None:
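+        # A nop select behaves like a size-1 slice: the output aliases the
+        # input at an offset of `index` steps along the selected dim, in bytes
+        # of the inner dims.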
+        spec = select_node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, select_node.args[1]) if len(select_node.args) > 1 else 0
+        select_outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            select_outer_size,
+            1,
+            f"{select_node=} has wrong outer size: {select_outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
+        index: int = (
+            cast(int, select_node.args[2])
+            if (len(select_node.args) > 2 and select_node.args[2] is not None)
+            else 0
+        )
+        arg = cast(torch.fx.Node, select_node.args[0])
+        arg_spec = arg.meta.get("spec", None)
+        self.assertEqual(arg_spec.mem_id, spec.mem_id)
+        self.assertEqual(
+            spec.mem_offset,
+            arg_spec.mem_offset + index * inner_dim_elements,
+            f"{arg=} for node {select_node=} has wrong memory offset: {arg_spec.mem_offset=} {index=} for select on {dim=}, but output has {spec.mem_offset=}",
+        )
+
+    def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
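+        # Check the memory layout of every nop cat/slice/select node produced
+        # by the memory transform passes.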
+        for cat_node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._cat_nop.out
+        ):
+            self.verify_cat_nop_memory_alloc(cat_node)
+
+        for slice_node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._slice_copy_nop.Tensor_out
+        ):
+            self.verify_slice_nop_memory_alloc(slice_node)
+
+        for select_node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._select_copy_nop.int_out
+        ):
+            self.verify_select_nop_memory_alloc(select_node)
+
+    def test_optimize_cat_on_placeholders(self):
+        class Cat(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.cat((x, y))
+
+        x = torch.ones(3, 6)
+        y = torch.ones(2, 6)
+        # Optimizing cat ops happens only at opt_level 2+, and requires the
+        # memory planning pass to run:
+        graph_module = (
+            compiler.export_to_executorch_gen_etrecord(
+                Cat(), (x, y), opt_level=2, mem_algo=1
+            )
+            .exported_program()
+            .graph_module
+        )
+        logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
+        graph_module.graph.eliminate_dead_code()
+        # Assert that the cat op is optimized away
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        # Assert that the cat op is replaced by its nop version post optimization
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_cat_outermost(self):
         class OptimizeCatFeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -135,7 +249,9 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         # Assert that cat op is replaced by its nop version post optimization
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_optimize_cat_non_outermost(self):
         class OptimizeCatFeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -160,7 +276,9 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         # Assert that cat op is replaced by its nop version post optimization
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_no_optimize_cat_non_outermost(self):
         class OptimizeCatInfeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -184,7 +302,9 @@ def forward(self, x, y):
         # Assert that cat op is not optimized away, since the concat is not
         # along the outermost dim
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_no_optimize_cat_non_outermost1(self):
         class OptimizeCatInfeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -209,6 +329,7 @@ def forward(self, x, y):
         # offsets are not multiples of 8 bytes, and the cat is not the output
         # of the graph.
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_cat_with_slice(self):
         class OptimizeCatSliceFeasible(torch.nn.Module):
@@ -237,6 +358,7 @@ def forward(self, x):
         graph_module.graph.eliminate_dead_code()
         # Assert that cat op is optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_cat_with_slice_infeasible(self):
         class OptimizeCatSliceInfeasible(torch.nn.Module):
@@ -262,6 +384,7 @@ def forward(self, x, y):
         graph_module.graph.eliminate_dead_code()
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_slice_Tensor(self):
         class SliceTensor(torch.nn.Module):
@@ -323,6 +446,7 @@ def forward(self, x, y, z):
         self.assertEqual(
             count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 3
         )
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_select_Tensor(self):
         class SelectTensor(torch.nn.Module):
@@ -387,6 +511,7 @@ def forward(self, x, y, z):
         self.assertEqual(
             count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 3
         )
+        self.verify_nop_memory_alloc(graph_module)

     # TODO: Test fails due to memory planning
     @unittest.expectedFailure
@@ -416,6 +541,32 @@ def forward(self, x, y):
         graph_module.graph.eliminate_dead_code()
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_cat_then_slice_on_mutable_buffer(self):
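+        # The cat result shares planned memory with the mutable buffer, so the
+        # cat should still be replaced with its nop version even though one of
+        # its operands is a mutable buffer.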
+        class CatWithPadding(torch.nn.Module):
+            def __init__(self, padding_shape):
+                super().__init__()
+                zeros = torch.zeros(padding_shape)
+                self.register_buffer("padding", zeros)
+
+            def forward(self, x, y):
+                x = x.view(3, 5)
+                cat = torch.ops.aten.cat((x, self.padding.clone()))
+                slice_copy = torch.ops.aten.slice(cat, dim=0, start=x.shape[0])
+                self.padding.copy_(slice_copy)
+                return cat.view(-1) + y
+
+        x = torch.ones(15)
+        y = torch.ones(1)
+        et_prog_manager = compiler.export_to_executorch_gen_etrecord(
+            CatWithPadding((1, 5)), (x, y), opt_level=3
+        )
+        graph_module = et_prog_manager.exported_program().graph_module
+        logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_cat_with_view(self):
         class CatViewFeasible(torch.nn.Module):
@@ -442,6 +593,7 @@ def forward(self, x, y):
         # Assert that cat op is optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_no_optimize_cat_with_repeated_args(self):
         class CatViewInfeasible(torch.nn.Module):
@@ -465,6 +617,7 @@ def forward(self, x):
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_no_optimize_cat_with_placeholder(self):
         class CatViewInfeasible(torch.nn.Module):
@@ -492,6 +645,7 @@ def forward(self, x, y):
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_no_optimize_cat(self) -> None:
         class Model(torch.nn.Module):
@@ -522,6 +676,7 @@ def forward(self, x) -> torch.Tensor:
             count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 2
         )
         self.assertEqual(count_node(graph_module, memory.view), 2)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_slice_copy(self) -> None:
         class Model(torch.nn.Module):
@@ -553,6 +708,7 @@ def forward(self, x) -> torch.Tensor:
             count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 0
         )
         self.assertEqual(count_node(graph_module, memory.view), 2)
+        self.verify_nop_memory_alloc(graph_module)

     def test_cat_then_cat(self) -> None:
         class Model(torch.nn.Module):
@@ -579,6 +735,7 @@ def forward(self, x) -> torch.Tensor:
         graph_module.print_readable()
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 2)
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_view_for_unallocated_output(self):
         class Model(torch.nn.Module):
@@ -602,3 +759,4 @@ def forward(self, x, y):
             .graph_module
         )
         self.assertEqual(count_node(graph_module, memory.view), 1)
+        self.verify_nop_memory_alloc(graph_module)