Skip to content

Commit bedbf9d

Browse files
committed
[executorch][emitter] Emit FQNs
Pull Request resolved: #7192

Emit FQNs for external tensors. In the emitter, store external tensors as:

```
// List of unique tensors.
external_constant_buffer: List[bytes]

// Map of {constant_tag: {fqn: index into external_constant_buffer}}.
// constant_tag: we may want to save multiple external constant files; constants are grouped together via the tag.
// {fqn: index}: multiple fqns may point to the same data buffer. This is for deduplication.
external_constant_map: Dict[str, Dict[str, int]]
```

ghstack-source-id: 256979743
@exported-using-ghexport

Differential Revision: [D66523226](https://our.internmc.facebook.com/intern/diff/D66523226/)
1 parent f3a1f7e commit bedbf9d

File tree

2 files changed

+73
-6
lines changed

2 files changed

+73
-6
lines changed

exir/emit/_emit_program.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ class EmitterOutput:
4747

4848
mutable_data: Optional[List[Buffer]]
4949

50+
# Constants are optionally stored in external files.
51+
# Aggregate unique external constants into one buffer.
52+
external_constant_buffer: List[bytes]
53+
# Each constant_tag groups a set of constants together.
54+
# {constant_tag: {fqn: index into external_constant_buffer}}
55+
external_constant_map: Optional[Dict[str, Dict[str, int]]]
56+
5057

5158
def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
5259
gm = exported_program.graph_module
@@ -199,4 +206,6 @@ def emit_program(
199206
if len(program_state.mutable_buffer) > 1
200207
else None
201208
),
209+
external_constant_buffer=program_state.external_constant_buffer,
210+
external_constant_map=program_state.external_constant_map,
202211
)

exir/emit/_emitter.py

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
DoubleList,
6464
EValue,
6565
ExecutionPlan,
66+
ExtraTensorInfo,
6667
FreeCall,
6768
Instruction,
6869
Int,
@@ -121,6 +122,14 @@ class _ProgramState:
121122
# and should be copied to Program.backend_delegate_data.
122123
backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)
123124

125+
# Constants are optionally stored in external files.
126+
# Aggregate unique external constants into one buffer.
127+
external_constant_buffer: List[bytes] = field(default_factory=list)
128+
external_constant_hash: Dict[str, int] = field(default_factory=dict)
129+
# Each constant_tag groups a set of constants together.
130+
# {constant_tag: {fqn: index into external_constant_buffer}}
131+
external_constant_map: Dict[str, Dict[str, int]] = field(default_factory=dict)
132+
124133

125134
@dataclass
126135
class _EmitterState:
@@ -364,6 +373,7 @@ def _save_new_const_tensor(
364373
buffer_data: bytes,
365374
hashed: str,
366375
allocation_info: Optional[AllocationDetails],
376+
constant_tag: str,
367377
) -> int:
368378
"""Saves a new constant tensor to the constant buffer and returns the buffer idx"""
369379

@@ -372,17 +382,45 @@ def _save_new_const_tensor(
372382

373383
# Update buffer_idx to point to the end of the list where we are adding the new buffer.
374384
buffer = Buffer(storage=buffer_data)
385+
386+
# Tensor is mutable with initial state.
375387
if allocation_info:
376388
buffer_idx = len(self.program_state.mutable_buffer)
377389
self.program_state.cached_spec_mutable_hash_values[hashed] = buffer_idx
378390
self.program_state.mutable_buffer.append(buffer)
391+
392+
# Tensor is constant.
379393
else:
380-
buffer_idx = len(self.program_state.constant_buffer)
381-
self.program_state.cached_spec_hash_values[hashed] = buffer_idx
382-
self.program_state.constant_buffer.append(buffer)
394+
# Tensor is stored outside of the PTE file.
395+
if (
396+
spec.extra_tensor_info is not None
397+
and spec.extra_tensor_info.fully_qualified_name is not None
398+
and spec.extra_tensor_info.location == DataLocation.EXTERNAL
399+
):
400+
assert (
401+
constant_tag is not None
402+
), "Constant tag is not set for external tensor"
403+
404+
buffer_idx = len(self.program_state.external_constant_buffer)
405+
self.program_state.external_constant_hash[hashed] = buffer_idx
406+
self.program_state.external_constant_buffer.append(buffer_data)
407+
if constant_tag not in self.program_state.external_constant_map:
408+
self.program_state.external_constant_map[constant_tag] = {}
409+
self.program_state.external_constant_map[constant_tag][
410+
spec.extra_tensor_info.fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
411+
] = buffer_idx
412+
413+
# Tensor is stored in the PTE file.
414+
else:
415+
buffer_idx = len(self.program_state.constant_buffer)
416+
self.program_state.cached_spec_hash_values[hashed] = buffer_idx
417+
self.program_state.constant_buffer.append(buffer)
418+
383419
return buffer_idx
384420

385-
def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
421+
def _tensor_spec_to_evalue(
422+
self, spec: TensorSpec, constant_tag: Optional[str] = None
423+
) -> EValue:
386424
"""Constructs an EValue from the given TensorSpec."""
387425

388426
allocation_info = None
@@ -420,13 +458,18 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
420458
buffer_idx = self.program_state.cached_spec_mutable_hash_values.get(
421459
hashed, -1
422460
)
461+
elif (
462+
spec.extra_tensor_info is not None
463+
and spec.extra_tensor_info.location == DataLocation.EXTERNAL
464+
):
465+
buffer_idx = self.program_state.external_constant_hash.get(hashed, -1)
423466
else:
424467
buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1)
425468

426469
# Haven't seen this constant before.
427470
if buffer_idx == -1:
428471
buffer_idx = self._save_new_const_tensor(
429-
spec, buffer_data, hashed, allocation_info
472+
spec, buffer_data, hashed, allocation_info, constant_tag
430473
)
431474

432475
if spec.const and spec.nbytes() != len(buffer_data):
@@ -1557,11 +1600,26 @@ def placeholder(
15571600
https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
15581601
"""
15591602
spec = self.node.meta["spec"]
1603+
constant_tag = self.node.meta.get("constant_tag", None)
15601604
is_user_input = True
15611605

15621606
if isinstance(target, str) and isinstance(spec, TensorSpec):
15631607
fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)
15641608

1609+
# If the placeholder has a constant_tag, it is external to the PTE file
1610+
# and requires a fqn and location=DataLocation.EXTERNAL
1611+
if constant_tag is not None:
1612+
assert (
1613+
fqn is not None
1614+
), "constant tagged tensors require a fully qualified name"
1615+
if spec.extra_tensor_info is None:
1616+
spec.extra_tensor_info = ExtraTensorInfo(
1617+
fully_qualified_name=fqn, location=DataLocation.EXTERNAL
1618+
)
1619+
else:
1620+
spec.extra_tensor_info.fully_qualified_name = fqn
1621+
spec.extra_tensor_info.location = DataLocation.EXTERNAL
1622+
15651623
# From the fqn find the corresponding tensor
15661624
real_tensor = None
15671625
if fqn in self.exported_program.state_dict:
@@ -1599,7 +1657,7 @@ def placeholder(
15991657
spec.const = not (is_user_input or is_mutable_buffer)
16001658

16011659
evalue = (
1602-
self._tensor_spec_to_evalue(spec)
1660+
self._tensor_spec_to_evalue(spec, constant_tag)
16031661
if isinstance(spec, TensorSpec)
16041662
else self._constant_to_evalue(spec, None)
16051663
)

0 commit comments

Comments
 (0)