63
63
DoubleList ,
64
64
EValue ,
65
65
ExecutionPlan ,
66
+ ExtraTensorInfo ,
66
67
FreeCall ,
67
68
Instruction ,
68
69
Int ,
76
77
ScalarType ,
77
78
String ,
78
79
Tensor ,
80
+ TensorDataLocation ,
79
81
TensorList ,
80
82
TensorShapeDynamism ,
81
83
)
@@ -121,6 +123,14 @@ class _ProgramState:
121
123
# and should be copied to Program.backend_delegate_data.
122
124
backend_delegate_data : List [BackendDelegateInlineData ] = field (default_factory = list )
123
125
126
+ # Constants are optionally stored in external files.
127
+ # Aggregate unique external constants into one buffer.
128
+ external_constant_buffer : List [bytes ] = field (default_factory = list )
129
+ external_constant_hash : Dict [str , int ] = field (default_factory = dict )
130
+ # Each constant_tag groups a set of constants together.
131
+ # {constant_tag: {fqn: index into external_constant_buffer}}
132
+ external_constant_map : Dict [str , Dict [str , int ]] = field (default_factory = dict )
133
+
124
134
125
135
@dataclass
126
136
class _EmitterState :
@@ -363,7 +373,8 @@ def _save_new_const_tensor(
363
373
spec : TensorSpec ,
364
374
buffer_data : bytes ,
365
375
hashed : str ,
366
- allocation_info : Optional [AllocationDetails ],
376
+ allocation_info : Optional [AllocationDetails ] = None ,
377
+ constant_tag : Optional [str ] = None ,
367
378
) -> int :
368
379
"""Saves a new constant tensor to the constant buffer and returns the buffer idx"""
369
380
@@ -372,17 +383,45 @@ def _save_new_const_tensor(
372
383
373
384
# Update buffer_idx to point to the end of the list where we are adding the new buffer.
374
385
buffer = Buffer (storage = buffer_data )
386
+
387
+ # Tensor is mutable with initial state.
375
388
if allocation_info :
376
389
buffer_idx = len (self .program_state .mutable_buffer )
377
390
self .program_state .cached_spec_mutable_hash_values [hashed ] = buffer_idx
378
391
self .program_state .mutable_buffer .append (buffer )
392
+
393
+ # Tensor is constant.
379
394
else :
380
- buffer_idx = len (self .program_state .constant_buffer )
381
- self .program_state .cached_spec_hash_values [hashed ] = buffer_idx
382
- self .program_state .constant_buffer .append (buffer )
395
+ # Tensor is stored outside of the PTE file.
396
+ if (
397
+ spec .extra_tensor_info is not None
398
+ and spec .extra_tensor_info .fully_qualified_name is not None
399
+ and spec .extra_tensor_info .location == TensorDataLocation .EXTERNAL
400
+ ):
401
+ assert (
402
+ constant_tag is not None
403
+ ), "Constant tag is not set for external tensor"
404
+
405
+ buffer_idx = len (self .program_state .external_constant_buffer )
406
+ self .program_state .external_constant_hash [hashed ] = buffer_idx
407
+ self .program_state .external_constant_buffer .append (buffer_data )
408
+ if constant_tag not in self .program_state .external_constant_map :
409
+ self .program_state .external_constant_map [constant_tag ] = {}
410
+ self .program_state .external_constant_map [constant_tag ][
411
+ spec .extra_tensor_info .fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
412
+ ] = buffer_idx
413
+
414
+ # Tensor is stored in the PTE file.
415
+ else :
416
+ buffer_idx = len (self .program_state .constant_buffer )
417
+ self .program_state .cached_spec_hash_values [hashed ] = buffer_idx
418
+ self .program_state .constant_buffer .append (buffer )
419
+
383
420
return buffer_idx
384
421
385
- def _tensor_spec_to_evalue (self , spec : TensorSpec ) -> EValue :
422
+ def _tensor_spec_to_evalue (
423
+ self , spec : TensorSpec , constant_tag : Optional [str ] = None
424
+ ) -> EValue :
386
425
"""Constructs an EValue from the given TensorSpec."""
387
426
388
427
allocation_info = None
@@ -420,13 +459,18 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
420
459
buffer_idx = self .program_state .cached_spec_mutable_hash_values .get (
421
460
hashed , - 1
422
461
)
462
+ elif (
463
+ spec .extra_tensor_info is not None
464
+ and spec .extra_tensor_info .location == TensorDataLocation .EXTERNAL
465
+ ):
466
+ buffer_idx = self .program_state .external_constant_hash .get (hashed , - 1 )
423
467
else :
424
468
buffer_idx = self .program_state .cached_spec_hash_values .get (hashed , - 1 )
425
469
426
470
# Haven't seen this constant before.
427
471
if buffer_idx == - 1 :
428
472
buffer_idx = self ._save_new_const_tensor (
429
- spec , buffer_data , hashed , allocation_info
473
+ spec , buffer_data , hashed , allocation_info , constant_tag
430
474
)
431
475
432
476
if spec .const and spec .nbytes () != len (buffer_data ):
@@ -1557,11 +1601,26 @@ def placeholder(
1557
1601
https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
1558
1602
"""
1559
1603
spec = self .node .meta ["spec" ]
1604
+ constant_tag = self .node .meta .get ("constant_tag" , None )
1560
1605
is_user_input = True
1561
1606
1562
1607
if isinstance (target , str ) and isinstance (spec , TensorSpec ):
1563
1608
fqn , is_mutable_buffer = self ._find_fqn_for_placeholder (target , spec )
1564
1609
1610
+ # If the placeholder has a constant_tag, it is external to the PTE file
1611
+ # and requires a fqn and location=TensorDataLocation.EXTERNAL
1612
+ if constant_tag is not None :
1613
+ assert (
1614
+ fqn is not None
1615
+ ), "constant tagged tensors require a fully qualified name"
1616
+ if spec .extra_tensor_info is None :
1617
+ spec .extra_tensor_info = ExtraTensorInfo (
1618
+ fully_qualified_name = fqn , location = TensorDataLocation .EXTERNAL
1619
+ )
1620
+ else :
1621
+ spec .extra_tensor_info .fully_qualified_name = fqn
1622
+ spec .extra_tensor_info .location = TensorDataLocation .EXTERNAL
1623
+
1565
1624
# From the fqn find the corresponding tensor
1566
1625
real_tensor = None
1567
1626
if fqn in self .exported_program .state_dict :
@@ -1599,7 +1658,7 @@ def placeholder(
1599
1658
spec .const = not (is_user_input or is_mutable_buffer )
1600
1659
1601
1660
evalue = (
1602
- self ._tensor_spec_to_evalue (spec )
1661
+ self ._tensor_spec_to_evalue (spec , constant_tag )
1603
1662
if isinstance (spec , TensorSpec )
1604
1663
else self ._constant_to_evalue (spec , None )
1605
1664
)
0 commit comments