@@ -1291,36 +1291,41 @@ class MutableStateModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.register_buffer("state", torch.zeros(1))
+                self.register_buffer("direct_copy_from_input", torch.zeros(1))
 
             def forward(self, x):
                 y = x + self.state
                 self.state.add_(1)
+                self.direct_copy_from_input.copy_(x)
                 return y
 
         model = to_edge(export(MutableStateModule(), (torch.zeros(1),), strict=True))
         self.assertEqual(count_copies(model.exported_program().graph_module), 0)
         # Before
         # graph():
-        # %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
-        # %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
-        # %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
-        # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
-        # %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
-        # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
-        # return (aten_add_tensor_1, aten_add_tensor)
+        # %b_state : [num_users=2] = placeholder[target=b_state]
+        # %b_direct_copy_from_input : [num_users=0] = placeholder[target=b_direct_copy_from_input]
+        # %_lifted_tensor_constant2 : [num_users=1] = placeholder[target=_lifted_tensor_constant2]
+        # %x : [num_users=2] = placeholder[target=x]
+        # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%x, %b_state), kwargs = {})
+        # %dim_order_ops__to_dim_order_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.dim_order_ops._to_dim_order_copy.default](args = (%_lifted_tensor_constant2,), kwargs = {dtype: torch.float32, dim_order: []})
+        # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%b_state, %dim_order_ops__to_dim_order_copy_default), kwargs = {})
+        # return (aten_add_tensor_1, x, aten_add_tensor)
 
         gm, _ = insert_write_back_for_buffers_pass(model.exported_program())
 
         # After
         # graph():
-        # %arg0_1 : [num_users=3] = placeholder[target=arg0_1]
-        # %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
-        # %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
-        # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
-        # %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
-        # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
-        # %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
-        # return (copy__default, aten_add_tensor)
-        self.assertEqual(count_copies(gm), 1)
+        # %b_state : [num_users=3] = placeholder[target=b_state]
+        # %b_direct_copy_from_input : [num_users=1] = placeholder[target=b_direct_copy_from_input]
+        # %_lifted_tensor_constant2 : [num_users=1] = placeholder[target=_lifted_tensor_constant2]
+        # %x : [num_users=2] = placeholder[target=x]
+        # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%x, %b_state), kwargs = {})
+        # %dim_order_ops__to_dim_order_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.dim_order_ops._to_dim_order_copy.default](args = (%_lifted_tensor_constant2,), kwargs = {dtype: torch.float32, dim_order: []})
+        # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%b_state, %dim_order_ops__to_dim_order_copy_default), kwargs = {})
+        # %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%b_state, %aten_add_tensor_1), kwargs = {})
+        # %copy__default_1 : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%b_direct_copy_from_input, %x), kwargs = {})
+        # return (copy__default, copy__default_1, aten_add_tensor)
+        self.assertEqual(count_copies(gm), 2)
 
 
     def test_remove_quantized_op_noop_pass(self) -> None:
         class TestAddSliceNoop(torch.nn.Module):
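
For reference, `count_copies` is a helper defined elsewhere in this test file. A minimal sketch of what it plausibly looks like, assuming it counts `call_function` nodes whose target is the in-place copy op that `insert_write_back_for_buffers_pass` emits (consistent with the 0-before / 2-after assertions in the hunk):

import torch
from torch.fx import GraphModule


def count_copies(gm: GraphModule) -> int:
    # A write-back shows up as call_function[target=torch.ops.aten.copy_.default],
    # matching %copy__default and %copy__default_1 in the "After" graph.
    return sum(
        node.op == "call_function" and node.target == torch.ops.aten.copy_.default
        for node in gm.graph.nodes
    )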
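The two expected copies follow from the graphs above: the pass writes the updated state (`%aten_add_tensor_1`) back into `%b_state`, writes the input `%x` back into `%b_direct_copy_from_input`, and returns both copies ahead of the user-visible output. A quick eager-mode check of the mutation semantics the pass is preserving (hypothetical usage, not part of the diff):

m = MutableStateModule()
print(m(torch.zeros(1)))          # tensor([0.]): y = x + state = 0 + 0
print(m.state)                    # tensor([1.]): state.add_(1) mutated in place
print(m.direct_copy_from_input)   # tensor([0.]): copy of the most recent input
print(m(torch.ones(1)))           # tensor([2.]): y = 1 + updated state (1)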