@@ -1534,26 +1534,49 @@ def forward(self, x):
1534
1534
self .assertEqual (len (program .constant_buffer [1 ].storage ), 8 )
1535
1535
1536
1536
def test_emit_lifted_tensor_constant (self ) -> None :
1537
- class LiftedConstants (nn .Module ):
1537
+ class LiftedTensorConstants (nn .Module ):
1538
1538
def __init__ (self ):
1539
1539
super ().__init__ ()
1540
1540
1541
1541
def forward (self , x ):
1542
1542
x = x * torch .tensor ([[4 , 3 ], [1 , 2 ], [5 , 6 ]], dtype = torch .float )
1543
1543
return x
1544
1544
1545
- model = LiftedConstants ()
1545
+ model = LiftedTensorConstants ()
1546
+ # Specify that we want to move non-lifted constants to external file
1547
+ et_cfg = ExecutorchBackendConfig (external_constants = True )
1548
+ program = to_edge (
1549
+ export (model , (torch .ones (3 , 2 ),), strict = True )
1550
+ ).to_executorch (et_cfg )
1551
+ program = program ._emitter_output .program
1552
+ exec_plan = program .execution_plan [0 ]
1553
+ # There should only be 1 input to this model.
1554
+ self .assertEqual (len (exec_plan .inputs ), 1 )
1555
+ self .assertEqual (len (program .constant_buffer ), 2 )
1556
+ self .assertEqual (len (program .constant_buffer [1 ].storage ), 24 )
1546
1557
1558
+ def test_emit_lifted_constant (self ) -> None :
1559
+ class LiftedConstants (nn .Module ):
1560
+ def __init__ (self ):
1561
+ super ().__init__ ()
1562
+
1563
+ def forward (self , x ):
1564
+ x = x + 1
1565
+ return x
1566
+
1567
+ model = LiftedConstants ()
1568
+ # Specify that we want to move non-lifted constants to external file
1569
+ et_cfg = ExecutorchBackendConfig (external_constants = True )
1547
1570
program = to_edge (
1548
1571
export (model , (torch .ones (3 , 2 ),), strict = True )
1549
- ).to_executorch ()
1572
+ ).to_executorch (et_cfg )
1550
1573
1551
1574
program = program ._emitter_output .program
1552
1575
exec_plan = program .execution_plan [0 ]
1553
1576
# There should only be 1 input to this model.
1554
1577
self .assertEqual (len (exec_plan .inputs ), 1 )
1555
1578
self .assertEqual (len (program .constant_buffer ), 2 )
1556
- self .assertEqual (len (program .constant_buffer [1 ].storage ), 24 )
1579
+ self .assertEqual (len (program .constant_buffer [1 ].storage ), 8 )
1557
1580
1558
1581
def test_mutable_buffers (self ) -> None :
1559
1582
def count_copies (gm : torch .fx .GraphModule ) -> int :
0 commit comments