63
63
DoubleList ,
64
64
EValue ,
65
65
ExecutionPlan ,
66
+ ExtraTensorInfo ,
66
67
FreeCall ,
67
68
Instruction ,
68
69
Int ,
@@ -121,6 +122,14 @@ class _ProgramState:
121
122
# and should be copied to Program.backend_delegate_data.
122
123
backend_delegate_data : List [BackendDelegateInlineData ] = field (default_factory = list )
123
124
125
+ # Constants are optionally stored in external files.
126
+ # Aggregate unique external constants into one buffer.
127
+ external_constant_buffer : List [bytes ] = field (default_factory = list )
128
+ external_constant_hash : Dict [str , int ] = field (default_factory = dict )
129
+ # Each constant_tag groups a set of constants together.
130
+ # {constant_tag: {fqn: index into external_constant_buffer}}
131
+ external_constant_map : Dict [str , Dict [str , int ]] = field (default_factory = dict )
132
+
124
133
125
134
@dataclass
126
135
class _EmitterState :
@@ -364,6 +373,7 @@ def _save_new_const_tensor(
364
373
buffer_data : bytes ,
365
374
hashed : str ,
366
375
allocation_info : Optional [AllocationDetails ],
376
+ constant_tag : str ,
367
377
) -> int :
368
378
"""Saves a new constant tensor to the constant buffer and returns the buffer idx"""
369
379
@@ -372,17 +382,45 @@ def _save_new_const_tensor(
372
382
373
383
# Update buffer_idx to point to the end of the list where we are adding the new buffer.
374
384
buffer = Buffer (storage = buffer_data )
385
+
386
+ # Tensor is mutable with initial state.
375
387
if allocation_info :
376
388
buffer_idx = len (self .program_state .mutable_buffer )
377
389
self .program_state .cached_spec_mutable_hash_values [hashed ] = buffer_idx
378
390
self .program_state .mutable_buffer .append (buffer )
391
+
392
+ # Tensor is constant.
379
393
else :
380
- buffer_idx = len (self .program_state .constant_buffer )
381
- self .program_state .cached_spec_hash_values [hashed ] = buffer_idx
382
- self .program_state .constant_buffer .append (buffer )
394
+ # Tensor is stored outside of the PTE file.
395
+ if (
396
+ spec .extra_tensor_info is not None
397
+ and spec .extra_tensor_info .fully_qualified_name is not None
398
+ and spec .extra_tensor_info .location == DataLocation .EXTERNAL
399
+ ):
400
+ assert (
401
+ constant_tag is not None
402
+ ), "Constant tag is not set for external tensor"
403
+
404
+ buffer_idx = len (self .program_state .external_constant_buffer )
405
+ self .program_state .external_constant_hash [hashed ] = buffer_idx
406
+ self .program_state .external_constant_buffer .append (buffer_data )
407
+ if constant_tag not in self .program_state .external_constant_map :
408
+ self .program_state .external_constant_map [constant_tag ] = {}
409
+ self .program_state .external_constant_map [constant_tag ][
410
+ spec .extra_tensor_info .fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
411
+ ] = buffer_idx
412
+
413
+ # Tensor is stored in the PTE file.
414
+ else :
415
+ buffer_idx = len (self .program_state .constant_buffer )
416
+ self .program_state .cached_spec_hash_values [hashed ] = buffer_idx
417
+ self .program_state .constant_buffer .append (buffer )
418
+
383
419
return buffer_idx
384
420
385
- def _tensor_spec_to_evalue (self , spec : TensorSpec ) -> EValue :
421
+ def _tensor_spec_to_evalue (
422
+ self , spec : TensorSpec , constant_tag : Optional [str ] = None
423
+ ) -> EValue :
386
424
"""Constructs an EValue from the given TensorSpec."""
387
425
388
426
allocation_info = None
@@ -420,13 +458,18 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
420
458
buffer_idx = self .program_state .cached_spec_mutable_hash_values .get (
421
459
hashed , - 1
422
460
)
461
+ elif (
462
+ spec .extra_tensor_info is not None
463
+ and spec .extra_tensor_info .location == DataLocation .EXTERNAL
464
+ ):
465
+ buffer_idx = self .program_state .external_constant_hash .get (hashed , - 1 )
423
466
else :
424
467
buffer_idx = self .program_state .cached_spec_hash_values .get (hashed , - 1 )
425
468
426
469
# Haven't seen this constant before.
427
470
if buffer_idx == - 1 :
428
471
buffer_idx = self ._save_new_const_tensor (
429
- spec , buffer_data , hashed , allocation_info
472
+ spec , buffer_data , hashed , allocation_info , constant_tag
430
473
)
431
474
432
475
if spec .const and spec .nbytes () != len (buffer_data ):
@@ -1557,11 +1600,26 @@ def placeholder(
1557
1600
https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
1558
1601
"""
1559
1602
spec = self .node .meta ["spec" ]
1603
+ constant_tag = self .node .meta .get ("constant_tag" , None )
1560
1604
is_user_input = True
1561
1605
1562
1606
if isinstance (target , str ) and isinstance (spec , TensorSpec ):
1563
1607
fqn , is_mutable_buffer = self ._find_fqn_for_placeholder (target , spec )
1564
1608
1609
+ # If the placeholder has a constant_tag, it is external to the PTE file
1610
+ # and requires a fqn and location=DataLocation.EXTERNAL
1611
+ if constant_tag is not None :
1612
+ assert (
1613
+ fqn is not None
1614
+ ), "constant tagged tensors require a fully qualified name"
1615
+ if spec .extra_tensor_info is None :
1616
+ spec .extra_tensor_info = ExtraTensorInfo (
1617
+ fully_qualified_name = fqn , location = DataLocation .EXTERNAL
1618
+ )
1619
+ else :
1620
+ spec .extra_tensor_info .fully_qualified_name = fqn
1621
+ spec .extra_tensor_info .location = DataLocation .EXTERNAL
1622
+
1565
1623
# From the fqn find the corresponding tensor
1566
1624
real_tensor = None
1567
1625
if fqn in self .exported_program .state_dict :
@@ -1599,7 +1657,7 @@ def placeholder(
1599
1657
spec .const = not (is_user_input or is_mutable_buffer )
1600
1658
1601
1659
evalue = (
1602
- self ._tensor_spec_to_evalue (spec )
1660
+ self ._tensor_spec_to_evalue (spec , constant_tag )
1603
1661
if isinstance (spec , TensorSpec )
1604
1662
else self ._constant_to_evalue (spec , None )
1605
1663
)
0 commit comments