@@ -519,6 +519,39 @@ def test_multiple_pools(
519
519
idx += 1
520
520
self .assertEqual (graph_module .meta ["non_const_buffer_sizes" ], expected_bufsizes )
521
521
522
+ def test_mutation_not_double_allocated (self ) -> None :
523
+ class Simple (torch .nn .Module ):
524
+ def __init__ (self ) -> None :
525
+ super ().__init__ ()
526
+ self .register_buffer ("constant" , torch .ones (5 , 5 ))
527
+
528
+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
529
+ self .constant .add_ (1 )
530
+ return x - self .constant
531
+
532
+ model = Simple ()
533
+ inputs = (torch .ones (5 , 5 ),)
534
+
535
+ et = to_edge (export (model , inputs , strict = True )).to_executorch ()
536
+
537
+ # 0 and 11 should refer to the same tensor. 0 is the input, 11 is the output of copy_
538
+ self .assertEqual (
539
+ et .executorch_program .execution_plan [0 ]
540
+ .values [0 ]
541
+ .val .allocation_info .memory_offset_low ,
542
+ et .executorch_program .execution_plan [0 ]
543
+ .values [11 ]
544
+ .val .allocation_info .memory_offset_low ,
545
+ )
546
+ self .assertEqual (
547
+ et .executorch_program .execution_plan [0 ]
548
+ .values [0 ]
549
+ .val .allocation_info .memory_offset_high ,
550
+ et .executorch_program .execution_plan [0 ]
551
+ .values [11 ]
552
+ .val .allocation_info .memory_offset_high ,
553
+ )
554
+
522
555
def test_constants_not_memory_planned (self ) -> None :
523
556
class Simple (torch .nn .Module ):
524
557
def __init__ (self ) -> None :
0 commit comments