@@ -32,8 +32,9 @@
 import hashlib
 import operator
 import typing
+import warnings
 from dataclasses import dataclass, field
-from typing import Callable, cast, Dict, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Tuple, Union
 
 import executorch.exir.memory as memory
 import executorch.extension.pytree as ex_pytree
@@ -1266,15 +1267,17 @@ def __init__(
         self,
         name: str,
         exported_program: ExportedProgram,
+        graph_module: torch.fx.GraphModule,
         program_state: _ProgramState,
         emitter_state: _EmitterState,
     ) -> None:
-        super().__init__(exported_program.graph_module, emitter_state, program_state)
+        super().__init__(graph_module, emitter_state, program_state)
         self.name = name
         self.exported_program = exported_program
 
         self.inputs: List[int] = []
         self.outputs: List[int] = []
+        self.given_mutable_buffer_warning = False
 
         def create_container_str(spec: Optional[pytree.TreeSpec]) -> str:
             if spec is None:
@@ -1293,6 +1296,42 @@ def create_container_str(spec: Optional[pytree.TreeSpec]) -> str:
             inp_container_str, out_container_str
         )
 
+    def _find_fqn_for_placeholder(
+        self, target: _Target, spec: Any  # pyre-ignore[2]
+    ) -> Tuple[Optional[str], bool]:
+        # Find the fully qualified name
+        fqn = None
+        is_mutable_buffer = False
+        if target in self.exported_program.graph_signature.inputs_to_parameters:
+            fqn = self.exported_program.graph_signature.inputs_to_parameters[target]
+
+        elif target in self.exported_program.graph_signature.inputs_to_buffers:
+            fqn = self.exported_program.graph_signature.inputs_to_buffers[target]
+
+            # if the buffer is mutated then record that
+            if fqn in self.exported_program.graph_signature.buffers_to_mutate.values():
+                is_mutable_buffer = True
+                if not self.given_mutable_buffer_warning:
+                    warnings.warn(
+                        "Mutation on a buffer in the model is detected. ExecuTorch assumes "
+                        "buffers that are mutated in the graph have a meaningless initial state, "
+                        "only the shape and dtype will be serialized.",
+                        UserWarning,
+                        stacklevel=1,
+                    )
+                    self.given_mutable_buffer_warning = True
+
+        elif (
+            target
+            in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
+        ):
+            fqn = (
+                self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
+                    target
+                ]
+            )
+        return fqn, is_mutable_buffer
+
     def placeholder(
         self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
     ) -> _AbstractValue:
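
For context, a minimal sketch (not part of this patch; the Counter module and its field names are made up) of the graph_signature maps that _find_fqn_for_placeholder walks. An in-place update to a registered buffer is what lands an entry in buffers_to_mutate and trips the warning above; the exact placeholder keys printed vary across torch versions.

# Illustrative only: toy module exercising the graph_signature maps.
import torch
from torch.export import export


class Counter(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)            # parameters -> inputs_to_parameters
        self.register_buffer("calls", torch.zeros(1))  # buffer -> inputs_to_buffers

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.calls.add_(1)  # in-place mutation, so "calls" also appears in buffers_to_mutate
        return self.linear(x)


ep = export(Counter(), (torch.randn(1, 2),))
sig = ep.graph_signature
print(sig.inputs_to_parameters)  # e.g. {"p_linear_weight": "linear.weight", ...}
print(sig.inputs_to_buffers)     # e.g. {"b_calls": "calls"}
print(sig.buffers_to_mutate)     # e.g. {"add": "calls"}  -> takes the mutable-buffer path
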
@@ -1302,40 +1341,27 @@ def placeholder(
         https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
         """
         spec = self.node.meta["spec"]
-        const_tensor = False
-        if isinstance(target, str) and (
-            target in self.exported_program.graph_signature.inputs_to_parameters
-            or target in self.exported_program.graph_signature.inputs_to_buffers
-            or target
-            in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
-        ):
-            if (
-                target
-                in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
-            ):
-                fqn = self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
-                    target
-                ]
-            elif target in self.exported_program.graph_signature.inputs_to_buffers:
-                fqn = self.exported_program.graph_signature.inputs_to_buffers[target]
-            else:
-                fqn = self.exported_program.graph_signature.inputs_to_parameters[target]
+        is_user_input = True
+
+        if isinstance(target, str) and isinstance(spec, TensorSpec):
+
+            fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)
+
+            # From the fqn find the corresponding tensor
+            real_tensor = None
             if fqn in self.exported_program.state_dict:
-                spec = TensorSpec.from_tensor(
-                    self.exported_program.state_dict[fqn], const=True
-                )
-                const_tensor = True
+                real_tensor = self.exported_program.state_dict[fqn]
+                is_user_input = False
+
             elif fqn in self.exported_program.constants:
-                spec = TensorSpec.from_tensor(
-                    self.exported_program.constants[fqn], const=True
-                )
-                const_tensor = True
-            else:
+                real_tensor = self.exported_program.constants[fqn]
+                is_user_input = False
+            elif fqn is not None:
                 buffers = self.exported_program.named_buffers()
                 buf = next((x[1] for x in buffers if x[0] == fqn), None)
                 if buf is not None:
-                    spec = TensorSpec.from_tensor(buf, const=True)
-                    const_tensor = True
+                    real_tensor = buf
+                    is_user_input = False
                 else:
                     raise InternalError(
                         self._emit_node_specific_error(
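
A rough sketch of where this lookup resolves from, assuming a recent torch.export where ExportedProgram exposes .state_dict and .constants as the code above relies on. The Tiny module and its attribute names are invented for illustration; only placeholders that resolve to none of these sources stay marked as user inputs.

# Illustrative only: the tensor sources the emitter consults for a placeholder.
import torch
from torch.export import export


class Tiny(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = torch.nn.Linear(2, 2)
        self.register_buffer("scale", torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin(x) * self.scale


ep = export(Tiny(), (torch.randn(1, 2),))

# Parameters and persistent buffers are keyed by fully qualified name here:
print(sorted(ep.state_dict.keys()))  # e.g. ['lin.bias', 'lin.weight', 'scale']
# Lifted tensor constants (if any) live separately:
print(sorted(ep.constants.keys()))
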
@@ -1344,13 +1370,28 @@ def placeholder(
                         )
                     )
 
+            # assign the storage of the placeholder spec to the storage of the real tensor if there is one
+            if real_tensor is not None:
+                # for non-contiguous tensors, convert to a contiguous one
+                real_tensor = real_tensor.contiguous()
+                # Weights cannot be views during emission or serialization
+                if real_tensor.nbytes != real_tensor.untyped_storage().nbytes():
+                    real_tensor = real_tensor.clone()
+
+                spec.storage = real_tensor.untyped_storage()
+
+            # User inputs and mutable buffers are not constants, other buffers or parameters are.
+            spec.const = not (is_user_input or is_mutable_buffer)
+
         evalue = (
             self._tensor_spec_to_evalue(spec)
             if isinstance(spec, TensorSpec)
             else self._constant_to_evalue(spec, None)
         )
         value = self._emit_evalue(evalue)
-        if not const_tensor:
+
+        # Only user inputs should remain as inputs.
+        if is_user_input:
             self.inputs.append(value.id)
 
         return value
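
The nbytes check in the hunk above catches tensors that are contiguous yet still alias a larger buffer, which is why a clone is needed before the untyped storage is serialized. A quick standalone demonstration, assuming current PyTorch tensor/storage APIs:

# Illustrative only: a slice is a view over a larger storage, so serializing
# its untyped_storage() would drag the whole parent buffer along.
import torch

base = torch.arange(16, dtype=torch.float32)
view = base[:4]

print(view.nbytes)                      # 16 bytes (4 floats)
print(view.untyped_storage().nbytes())  # 64 bytes (shares base's storage)

standalone = view.clone()               # fresh storage sized exactly to the tensor
print(standalone.nbytes == standalone.untyped_storage().nbytes())  # True
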