
Commit 4655202

Add pass to extract mutable weights into a .ptd
Differential Revision: D68121580
Pull Request resolved: #7798

1 parent e342093 · commit 4655202

5 files changed (+146, −37 lines)

exir/capture/_config.py

Lines changed: 4 additions & 0 deletions
@@ -92,3 +92,7 @@ class ExecutorchBackendConfig:
     # If set to true, all constant tensors will be stored in a separate file,
     # external to the PTE file.
     external_constants: bool = False
+
+    # If set to true, all trainable weights will be stored in a separate file,
+    # external to the PTE file.
+    external_mutable_weights: bool = False
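
For context, a minimal usage sketch of the new flag, mirroring the test added later in this commit. This is a sketch only: it assumes ep is an ExportedProgram holding a joint forward/backward training graph, stops at emission (writing the weights out to a .ptd file is not part of this diff), and assumes both imported names are re-exported from executorch.exir.

    from executorch.exir import ExecutorchBackendConfig, to_edge

    edge = to_edge(ep)  # ep: joint forward/backward ExportedProgram (assumed)
    et = edge.to_executorch(
        config=ExecutorchBackendConfig(external_mutable_weights=True)
    )
    # Trainable weights are now emitted to the external constant buffer under
    # the "_default_external_constant" key instead of the PTE constant buffer.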

exir/emit/_emitter.py

Lines changed: 27 additions & 29 deletions
@@ -387,38 +387,36 @@ def _save_new_const_tensor(
         # Update buffer_idx to point to the end of the list where we are adding the new buffer.
         buffer = Buffer(storage=buffer_data)

-        # Tensor is mutable with initial state.
-        if allocation_info:
+        # Tensor is stored outside of the PTE file.
+        if (
+            spec.extra_tensor_info is not None
+            and spec.extra_tensor_info.fully_qualified_name is not None
+            and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
+        ):
+            assert (
+                constant_tag is not None
+            ), "Constant tag is not set for external tensor"
+            # TODO (#7633): Handle case where we have both mutable and non mutable weights that we want to put in the same external file.
+            # We will need to create 2 segments in that case, but it'll be a bit until we see this case. LLM finetuning will probably require this.
+
+            buffer_idx = len(self.program_state.external_constant_buffer)
+            self.program_state.external_constant_hash[hashed] = buffer_idx
+            self.program_state.external_constant_buffer.append(buffer_data)
+            if constant_tag not in self.program_state.external_constant_map:
+                self.program_state.external_constant_map[constant_tag] = {}
+            self.program_state.external_constant_map[constant_tag][
+                spec.extra_tensor_info.fully_qualified_name  # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
+            ] = buffer_idx
+        # Tensor is mutable with initial state. Place into mutable segment
+        elif allocation_info:
             buffer_idx = len(self.program_state.mutable_buffer)
             self.program_state.cached_spec_mutable_hash_values[hashed] = buffer_idx
             self.program_state.mutable_buffer.append(buffer)
-
-        # Tensor is constant.
+        # Tensor is stored in the PTE file.
         else:
-            # Tensor is stored outside of the PTE file.
-            if (
-                spec.extra_tensor_info is not None
-                and spec.extra_tensor_info.fully_qualified_name is not None
-                and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
-            ):
-                assert (
-                    constant_tag is not None
-                ), "Constant tag is not set for external tensor"
-
-                buffer_idx = len(self.program_state.external_constant_buffer)
-                self.program_state.external_constant_hash[hashed] = buffer_idx
-                self.program_state.external_constant_buffer.append(buffer_data)
-                if constant_tag not in self.program_state.external_constant_map:
-                    self.program_state.external_constant_map[constant_tag] = {}
-                self.program_state.external_constant_map[constant_tag][
-                    spec.extra_tensor_info.fully_qualified_name  # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
-                ] = buffer_idx
-
-            # Tensor is stored in the PTE file.
-            else:
-                buffer_idx = len(self.program_state.constant_buffer)
-                self.program_state.cached_spec_hash_values[hashed] = buffer_idx
-                self.program_state.constant_buffer.append(buffer)
+            buffer_idx = len(self.program_state.constant_buffer)
+            self.program_state.cached_spec_hash_values[hashed] = buffer_idx
+            self.program_state.constant_buffer.append(buffer)

         return buffer_idx

@@ -458,7 +456,7 @@ def _tensor_spec_to_evalue(

             hashed = hashlib.sha256(buffer_data).hexdigest()

-            if allocation_info:
+            if allocation_info and spec.extra_tensor_info is None:
                 buffer_idx = self.program_state.cached_spec_mutable_hash_values.get(
                     hashed, -1
                 )
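
Net effect: _save_new_const_tensor now checks for external placement before the mutable-with-initial-state case, and _tensor_spec_to_evalue only reuses the mutable-hash cache when the tensor has no extra_tensor_info, so externally tagged mutable weights fall through to the external branch. A hedged distillation of that priority (a standalone sketch, not the real method; TensorDataLocation is assumed to be the same schema enum the emitter uses):

    def choose_buffer_sketch(spec, allocation_info) -> str:
        # Hypothetical helper summarizing the branch order in the diff above.
        is_external = (
            spec.extra_tensor_info is not None
            and spec.extra_tensor_info.fully_qualified_name is not None
            and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
        )
        if is_external:           # external placement wins, even for mutable tensors
            return "external_constant_buffer"
        if allocation_info:       # mutable tensor with an initial state
            return "mutable_buffer"
        return "constant_buffer"  # everything else stays inside the PTE file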

exir/emit/test/test_emit.py

Lines changed: 51 additions & 0 deletions
@@ -67,6 +67,7 @@
 from torch import nn

 from torch.export import Dim, export, export_for_training
+from torch.export.experimental import _export_forward_backward


 class WrapperModule(torch.nn.Module):

@@ -1733,3 +1734,53 @@ def forward(self, x):
         self.assertEqual(
             len(edge_program_manager.executorch_program.backend_delegate_data), 1
         )
+
+    def test_constant_tagged_mutable_tensors(self) -> None:
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        # On device training requires the loss to be embedded in the model (and be the first output).
+        # We wrap the original model here and add the loss calculation. This will be the model we export.
+        class TrainingNet(nn.Module):
+            def __init__(self, net):
+                super().__init__()
+                self.net = net
+                self.loss = nn.CrossEntropyLoss()
+
+            def forward(self, input, label):
+                pred = self.net(input)
+                return self.loss(pred, label), pred.detach().argmax(dim=1)
+
+        net = TrainingNet(Net())
+
+        # Captures the forward graph. The graph will look similar to the model definition now.
+        # Will move to export_for_training soon which is the api planned to be supported in the long term.
+        ep = export(
+            net, (torch.randn(1, 2), torch.ones(1, dtype=torch.int64)), strict=True
+        )
+        # Captures the backward graph. The exported_program now contains the joint forward and backward graph.
+        ep = _export_forward_backward(ep)
+        # Lower the graph to edge dialect.
+        ep = to_edge(ep)
+        # Lower the graph to executorch.
+        ep = ep.to_executorch(
+            config=ExecutorchBackendConfig(external_mutable_weights=True)
+        )
+
+        emitter_output = ep._emitter_output
+        # Check that constant_buffer is empty besides the non-constant placeholder 0.
+        self.assertEqual(len(emitter_output.program.constant_buffer), 1)
+        # Check that constant weights are in the external constant buffer.
+        self.assertEqual(len(emitter_output.external_constant_buffer), 2)
+        # Setting external_mutable_weights=True, saves all constants with an associated gradient to the key
+        # '_default_external_constant'.
+        external_map = emitter_output.external_constant_map[
+            "_default_external_constant"
+        ]
+        self.assertEqual(external_map["net.linear.weight"], 0)
+        self.assertEqual(external_map["net.linear.bias"], 1)

exir/passes/external_constants_pass.py

Lines changed: 50 additions & 5 deletions
@@ -7,17 +7,20 @@
 # pyre-strict

 import torch
+from executorch.exir.pass_base import PassResult
 from executorch.exir.tensor import TensorSpec
-from torch.export.exported_program import ExportedProgram
+from torch.export.exported_program import ExportedProgram, OutputKind
+from torch.fx import GraphModule


 def external_constants_pass(
-    ep: ExportedProgram,
-) -> ExportedProgram:
+    gm: GraphModule,
+) -> PassResult:
     """
     Move all constants to external file.
     """
-    for module in ep.graph_module.modules():
+    mutated = False
+    for module in gm.modules():
         if not isinstance(module, torch.fx.GraphModule):
             continue

@@ -26,4 +29,46 @@ def external_constants_pass(
                 spec = node.meta.get("spec")
                 if isinstance(spec, TensorSpec) and spec.const:
                     node.meta["constant_tag"] = "_default_external_constant"
-    return ep
+                    mutated = True
+    return PassResult(gm, mutated)
+
+
+def _is_mutable_weight(node: torch.fx.Node, ep: ExportedProgram) -> bool:
+    grad_targets = [
+        spec.target
+        for spec in ep.graph_signature.output_specs
+        if spec.kind == OutputKind.GRADIENT_TO_PARAMETER
+    ]
+    return (
+        node.op == "placeholder"
+        and node.target in ep.graph_signature.inputs_to_parameters.keys()
+        and ep.graph_signature.inputs_to_parameters[node.target] in grad_targets
+    )
+
+
+def external_mutable_weights_pass(
+    gm: GraphModule,
+    ep: ExportedProgram,
+) -> PassResult:
+    """
+    Move all mutable weights to external file.
+    """
+    # pass the gm and the ep seperately as the gm is being mutated by a bunch of passes in to_executorch,
+    # so the gm in the ep is lagging the graph signature is still correct.
+    # This is really tech debt and all the passes should be refactored to just mutate the ep.
+    mutated = False
+    for module in gm.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+
+        for node in module.graph.nodes:
+            if node.op == "placeholder":
+                spec = node.meta.get("spec")
+                if (
+                    isinstance(spec, TensorSpec)
+                    and spec.const
+                    and _is_mutable_weight(node, ep)
+                ):
+                    node.meta["constant_tag"] = "_default_external_constant"
+                    mutated = True
+    return PassResult(gm, mutated)
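
For illustration, a hedged sketch of invoking the new pass directly, using only the signatures shown above. In practice to_executorch calls it when external_mutable_weights=True; gm and ep here are assumed to come from an earlier export / _export_forward_backward / to_edge step.

    from executorch.exir.passes.external_constants_pass import (
        external_mutable_weights_pass,
    )

    # gm: the GraphModule being lowered; ep: the ExportedProgram whose
    # graph_signature still records GRADIENT_TO_PARAMETER outputs.
    result = external_mutable_weights_pass(gm, ep)
    gm = result.graph_module
    # Placeholders for trainable weights are now tagged with
    # constant_tag="_default_external_constant" for the emitter to pick up.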

exir/program/_program.py

Lines changed: 14 additions & 3 deletions
@@ -35,7 +35,10 @@
     MemoryFormatOpsPass,
     OpReplacePass,
 )
-from executorch.exir.passes.external_constants_pass import external_constants_pass
+from executorch.exir.passes.external_constants_pass import (
+    external_constants_pass,
+    external_mutable_weights_pass,
+)
 from executorch.exir.passes.insert_write_back_for_buffers_pass import (
     insert_write_back_for_buffers_pass,
 )

@@ -1395,6 +1398,14 @@ def to_executorch(
                     # TODO(who?)
                     p.update_placeholder_tensor_specs(program, new_gm)

+            # Extract constants if the config says too.
+            if config.external_constants:
+                new_gm_res = external_constants_pass(new_gm)
+                new_gm = new_gm_res.graph_module
+            elif config.external_mutable_weights:
+                new_gm_res = external_mutable_weights_pass(new_gm, program)
+                new_gm = new_gm_res.graph_module
+
             if isinstance(config.memory_planning_pass, dict):
                 memory_planning_pass = config.memory_planning_pass.get(
                     name, ExecutorchBackendConfig().memory_planning_pass

@@ -1409,8 +1420,8 @@
             else:
                 new_gm_res = memory_planning_pass(new_gm)  # pyre-ignore[29]

-            if config.external_constants:
-                new_gm_res = external_constants_pass(new_gm_res)
+            # WARNING: DO NOT ADD ANY MORE PASSES AFTER MEMORY PLANNING PASS.
+            # THERE ARE A LOT OF ASSUMPTIONS IN THE STACK THAT MEMORY PLANNING IS THE LAST PASS BEFORE THE EMITTER.
             assert new_gm_res is not None
             new_gm = new_gm_res.graph_module
