Skip to content

Commit 37363c9

Browse files
committed
[executorch][emitter] Emit FQNs
Pull Request resolved: #7192 Emit FQNs for external tensors. In the emitter, store external tensors as: ``` // list of unique tensors external_constants_buffer: List[bytes] // map of {constant_tag: {fqn: index into external_constant_buffer}} // constant_tag: may want to save multiple external constant files; group them together via the tag. // {fqn: index}; there may be multiple fqns pointing to the same data buffer. This is for deduplication. external_constants_map: [Dict[str, Dict[str, int]] ``` ghstack-source-id: 257245125 @exported-using-ghexport Differential Revision: [D66523226](https://our.internmc.facebook.com/intern/diff/D66523226/)
1 parent b9db0a3 commit 37363c9

File tree

2 files changed

+75
-7
lines changed

2 files changed

+75
-7
lines changed

exir/emit/_emit_program.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ class EmitterOutput:
4747

4848
mutable_data: Optional[List[Buffer]]
4949

50+
# Constants are optionally stored in external files.
51+
# Aggregate unique external constants into one buffer.
52+
external_constant_buffer: List[bytes]
53+
# Each constant_tag groups a set of constants together.
54+
# {constant_tag: {fqn: index into external_constant_buffer}}
55+
external_constant_map: Optional[Dict[str, Dict[str, int]]]
56+
5057

5158
def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
5259
gm = exported_program.graph_module
@@ -199,4 +206,6 @@ def emit_program(
199206
if len(program_state.mutable_buffer) > 1
200207
else None
201208
),
209+
external_constant_buffer=program_state.external_constant_buffer,
210+
external_constant_map=program_state.external_constant_map,
202211
)

exir/emit/_emitter.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
DoubleList,
6464
EValue,
6565
ExecutionPlan,
66+
ExtraTensorInfo,
6667
FreeCall,
6768
Instruction,
6869
Int,
@@ -76,6 +77,7 @@
7677
ScalarType,
7778
String,
7879
Tensor,
80+
TensorDataLocation,
7981
TensorList,
8082
TensorShapeDynamism,
8183
)
@@ -121,6 +123,14 @@ class _ProgramState:
121123
# and should be copied to Program.backend_delegate_data.
122124
backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)
123125

126+
# Constants are optionally stored in external files.
127+
# Aggregate unique external constants into one buffer.
128+
external_constant_buffer: List[bytes] = field(default_factory=list)
129+
external_constant_hash: Dict[str, int] = field(default_factory=dict)
130+
# Each constant_tag groups a set of constants together.
131+
# {constant_tag: {fqn: index into external_constant_buffer}}
132+
external_constant_map: Dict[str, Dict[str, int]] = field(default_factory=dict)
133+
124134

125135
@dataclass
126136
class _EmitterState:
@@ -363,7 +373,8 @@ def _save_new_const_tensor(
363373
spec: TensorSpec,
364374
buffer_data: bytes,
365375
hashed: str,
366-
allocation_info: Optional[AllocationDetails],
376+
allocation_info: Optional[AllocationDetails] = None,
377+
constant_tag: Optional[str] = None,
367378
) -> int:
368379
"""Saves a new constant tensor to the constant buffer and returns the buffer idx"""
369380

@@ -372,17 +383,45 @@ def _save_new_const_tensor(
372383

373384
# Update buffer_idx to point to the end of the list where we are adding the new buffer.
374385
buffer = Buffer(storage=buffer_data)
386+
387+
# Tensor is mutable with initial state.
375388
if allocation_info:
376389
buffer_idx = len(self.program_state.mutable_buffer)
377390
self.program_state.cached_spec_mutable_hash_values[hashed] = buffer_idx
378391
self.program_state.mutable_buffer.append(buffer)
392+
393+
# Tensor is constant.
379394
else:
380-
buffer_idx = len(self.program_state.constant_buffer)
381-
self.program_state.cached_spec_hash_values[hashed] = buffer_idx
382-
self.program_state.constant_buffer.append(buffer)
395+
# Tensor is stored outside of the PTE file.
396+
if (
397+
spec.extra_tensor_info is not None
398+
and spec.extra_tensor_info.fully_qualified_name is not None
399+
and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
400+
):
401+
assert (
402+
constant_tag is not None
403+
), "Constant tag is not set for external tensor"
404+
405+
buffer_idx = len(self.program_state.external_constant_buffer)
406+
self.program_state.external_constant_hash[hashed] = buffer_idx
407+
self.program_state.external_constant_buffer.append(buffer_data)
408+
if constant_tag not in self.program_state.external_constant_map:
409+
self.program_state.external_constant_map[constant_tag] = {}
410+
self.program_state.external_constant_map[constant_tag][
411+
spec.extra_tensor_info.fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
412+
] = buffer_idx
413+
414+
# Tensor is stored in the PTE file.
415+
else:
416+
buffer_idx = len(self.program_state.constant_buffer)
417+
self.program_state.cached_spec_hash_values[hashed] = buffer_idx
418+
self.program_state.constant_buffer.append(buffer)
419+
383420
return buffer_idx
384421

385-
def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
422+
def _tensor_spec_to_evalue(
423+
self, spec: TensorSpec, constant_tag: Optional[str] = None
424+
) -> EValue:
386425
"""Constructs an EValue from the given TensorSpec."""
387426

388427
allocation_info = None
@@ -420,13 +459,18 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
420459
buffer_idx = self.program_state.cached_spec_mutable_hash_values.get(
421460
hashed, -1
422461
)
462+
elif (
463+
spec.extra_tensor_info is not None
464+
and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
465+
):
466+
buffer_idx = self.program_state.external_constant_hash.get(hashed, -1)
423467
else:
424468
buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1)
425469

426470
# Haven't seen this constant before.
427471
if buffer_idx == -1:
428472
buffer_idx = self._save_new_const_tensor(
429-
spec, buffer_data, hashed, allocation_info
473+
spec, buffer_data, hashed, allocation_info, constant_tag
430474
)
431475

432476
if spec.const and spec.nbytes() != len(buffer_data):
@@ -1557,11 +1601,26 @@ def placeholder(
15571601
https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
15581602
"""
15591603
spec = self.node.meta["spec"]
1604+
constant_tag = self.node.meta.get("constant_tag", None)
15601605
is_user_input = True
15611606

15621607
if isinstance(target, str) and isinstance(spec, TensorSpec):
15631608
fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)
15641609

1610+
# If the placeholder has a constant_tag, it is external to the PTE file
1611+
# and requires a fqn and location=TensorDataLocation.EXTERNAL
1612+
if constant_tag is not None:
1613+
assert (
1614+
fqn is not None
1615+
), "constant tagged tensors require a fully qualified name"
1616+
if spec.extra_tensor_info is None:
1617+
spec.extra_tensor_info = ExtraTensorInfo(
1618+
fully_qualified_name=fqn, location=TensorDataLocation.EXTERNAL
1619+
)
1620+
else:
1621+
spec.extra_tensor_info.fully_qualified_name = fqn
1622+
spec.extra_tensor_info.location = TensorDataLocation.EXTERNAL
1623+
15651624
# From the fqn find the corresponding tensor
15661625
real_tensor = None
15671626
if fqn in self.exported_program.state_dict:
@@ -1599,7 +1658,7 @@ def placeholder(
15991658
spec.const = not (is_user_input or is_mutable_buffer)
16001659

16011660
evalue = (
1602-
self._tensor_spec_to_evalue(spec)
1661+
self._tensor_spec_to_evalue(spec, constant_tag)
16031662
if isinstance(spec, TensorSpec)
16041663
else self._constant_to_evalue(spec, None)
16051664
)

0 commit comments

Comments
 (0)