
Commit b67b193

JacobSzwejbka authored and facebook-github-bot committed
Add pass to extract mutable weights into a .ptd (#7798)
Summary: Cleaned up the existing pass and fixed a typing error (the pass now returns a PassResult rather than an ExportedProgram). Added another option to the backend config to extract only mutable weights, which training workflows will use. Fixed the ordering of the ExecuTorch passes and added a warning not to add passes after memory planning (this particular pass was actually fine, but in general we like the invariant that memory planning runs last). Fixed the emitter to prioritize external placement over mutable placement. Down the line we will need to support intermixing mutable and non-mutable weights in the same .ptd (the stakes are memory regressions, not correctness), but no one needs that today, so it is deferred.

Reviewed By: lucylq

Differential Revision: D68121580
1 parent 948fba6 commit b67b193
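
For orientation, the workflow this commit enables looks roughly like the following, condensed from the test added in exir/emit/test/test_emit.py below (the toy network and shapes are illustrative; _export_forward_backward expects the loss as the model's first output):

import torch
from torch import nn
from torch.export import export
from torch.export.experimental import _export_forward_backward
from executorch.exir import ExecutorchBackendConfig, to_edge

class TrainingNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x, label):
        pred = self.linear(x)
        # The loss must be the first output for on-device training.
        return self.loss(pred, label), pred.detach().argmax(dim=1)

ep = export(
    TrainingNet(), (torch.randn(1, 2), torch.ones(1, dtype=torch.int64)), strict=True
)
ep = _export_forward_backward(ep)  # joint forward/backward graph
et = to_edge(ep).to_executorch(
    config=ExecutorchBackendConfig(external_mutable_weights=True)
)
# The trainable weight and bias now live in the external constant buffer
# (destined for the .ptd file) instead of the PTE's constant buffer.
print(len(et._emitter_output.external_constant_buffer))  # 2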

File tree

5 files changed, +146 −37 lines changed


exir/capture/_config.py

Lines changed: 4 additions & 0 deletions
@@ -92,3 +92,7 @@ class ExecutorchBackendConfig:
     # If set to true, all constant tensors will be stored in a separate file,
     # external to the PTE file.
     external_constants: bool = False
+
+    # If set to true, all trainable weights will be stored in a separate file,
+    # external to the PTE file.
+    external_mutable_weights: bool = False

exir/emit/_emitter.py

Lines changed: 27 additions & 29 deletions
@@ -387,38 +387,36 @@ def _save_new_const_tensor(
         # Update buffer_idx to point to the end of the list where we are adding the new buffer.
         buffer = Buffer(storage=buffer_data)

-        # Tensor is mutable with initial state.
-        if allocation_info:
+        # Tensor is stored outside of the PTE file.
+        if (
+            spec.extra_tensor_info is not None
+            and spec.extra_tensor_info.fully_qualified_name is not None
+            and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
+        ):
+            assert (
+                constant_tag is not None
+            ), "Constant tag is not set for external tensor"
+            # TODO (#7633): Handle the case where we have both mutable and non-mutable weights that we want to put in the same external file.
+            # We will need to create 2 segments in that case, but it will be a while until we see this case. LLM finetuning will probably require this.
+
+            buffer_idx = len(self.program_state.external_constant_buffer)
+            self.program_state.external_constant_hash[hashed] = buffer_idx
+            self.program_state.external_constant_buffer.append(buffer_data)
+            if constant_tag not in self.program_state.external_constant_map:
+                self.program_state.external_constant_map[constant_tag] = {}
+            self.program_state.external_constant_map[constant_tag][
+                spec.extra_tensor_info.fully_qualified_name  # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
+            ] = buffer_idx
+        # Tensor is mutable with initial state. Place it into the mutable segment.
+        elif allocation_info:
             buffer_idx = len(self.program_state.mutable_buffer)
             self.program_state.cached_spec_mutable_hash_values[hashed] = buffer_idx
             self.program_state.mutable_buffer.append(buffer)
-
-        # Tensor is constant.
+        # Tensor is stored in the PTE file.
         else:
-            # Tensor is stored outside of the PTE file.
-            if (
-                spec.extra_tensor_info is not None
-                and spec.extra_tensor_info.fully_qualified_name is not None
-                and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
-            ):
-                assert (
-                    constant_tag is not None
-                ), "Constant tag is not set for external tensor"
-
-                buffer_idx = len(self.program_state.external_constant_buffer)
-                self.program_state.external_constant_hash[hashed] = buffer_idx
-                self.program_state.external_constant_buffer.append(buffer_data)
-                if constant_tag not in self.program_state.external_constant_map:
-                    self.program_state.external_constant_map[constant_tag] = {}
-                self.program_state.external_constant_map[constant_tag][
-                    spec.extra_tensor_info.fully_qualified_name  # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
-                ] = buffer_idx
-
-            # Tensor is stored in the PTE file.
-            else:
-                buffer_idx = len(self.program_state.constant_buffer)
-                self.program_state.cached_spec_hash_values[hashed] = buffer_idx
-                self.program_state.constant_buffer.append(buffer)
+            buffer_idx = len(self.program_state.constant_buffer)
+            self.program_state.cached_spec_hash_values[hashed] = buffer_idx
+            self.program_state.constant_buffer.append(buffer)

         return buffer_idx

@@ -458,7 +456,7 @@ def _tensor_spec_to_evalue(

         hashed = hashlib.sha256(buffer_data).hexdigest()

-        if allocation_info:
+        if allocation_info and spec.extra_tensor_info is None:
             buffer_idx = self.program_state.cached_spec_mutable_hash_values.get(
                 hashed, -1
             )
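
Net effect: _save_new_const_tensor now checks for external placement before the mutable-with-initial-state path, so a tensor that is both mutable and tagged EXTERNAL lands in the external file. A tiny runnable sketch of just that priority order (the helper below is illustrative, not the emitter's API):

def choose_placement(is_external: bool, has_allocation_info: bool) -> str:
    # External tensors win first: they go to the external (.ptd) file even if mutable.
    if is_external:
        return "external"
    # Mutable tensors with an initial state go to the mutable segment.
    if has_allocation_info:
        return "mutable"
    # Everything else stays in the PTE constant buffer.
    return "constant"

assert choose_placement(True, True) == "external"  # external now beats mutable
assert choose_placement(False, True) == "mutable"
assert choose_placement(False, False) == "constant"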

exir/emit/test/test_emit.py

Lines changed: 51 additions & 0 deletions
@@ -67,6 +67,7 @@
 from torch import nn

 from torch.export import Dim, export, export_for_training
+from torch.export.experimental import _export_forward_backward


 class WrapperModule(torch.nn.Module):

@@ -1733,3 +1734,53 @@ def forward(self, x):
         self.assertEqual(
             len(edge_program_manager.executorch_program.backend_delegate_data), 1
         )
+
+    def test_constant_tagged_mutable_tensors(self) -> None:
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        # On-device training requires the loss to be embedded in the model (and be the first output).
+        # We wrap the original model here and add the loss calculation. This will be the model we export.
+        class TrainingNet(nn.Module):
+            def __init__(self, net):
+                super().__init__()
+                self.net = net
+                self.loss = nn.CrossEntropyLoss()
+
+            def forward(self, input, label):
+                pred = self.net(input)
+                return self.loss(pred, label), pred.detach().argmax(dim=1)
+
+        net = TrainingNet(Net())
+
+        # Captures the forward graph. The graph will look similar to the model definition now.
+        # Will move to export_for_training soon, which is the API planned to be supported in the long term.
+        ep = export(
+            net, (torch.randn(1, 2), torch.ones(1, dtype=torch.int64)), strict=True
+        )
+        # Captures the backward graph. The exported_program now contains the joint forward and backward graph.
+        ep = _export_forward_backward(ep)
+        # Lower the graph to edge dialect.
+        ep = to_edge(ep)
+        # Lower the graph to executorch.
+        ep = ep.to_executorch(
+            config=ExecutorchBackendConfig(external_mutable_weights=True)
+        )
+
+        emitter_output = ep._emitter_output
+        # Check that constant_buffer is empty besides the non-constant placeholder 0.
+        self.assertEqual(len(emitter_output.program.constant_buffer), 1)
+        # Check that constant weights are in the external constant buffer.
+        self.assertEqual(len(emitter_output.external_constant_buffer), 2)
+        # Setting external_mutable_weights=True saves all constants with an associated gradient
+        # under the key '_default_external_constant'.
+        external_map = emitter_output.external_constant_map[
+            "_default_external_constant"
+        ]
+        self.assertEqual(external_map["net.linear.weight"], 0)
+        self.assertEqual(external_map["net.linear.bias"], 1)
exir/passes/external_constants_pass.py

Lines changed: 50 additions & 5 deletions
@@ -7,17 +7,20 @@
 # pyre-strict

 import torch
+from executorch.exir.pass_base import PassResult
 from executorch.exir.tensor import TensorSpec
-from torch.export.exported_program import ExportedProgram
+from torch.export.exported_program import ExportedProgram, OutputKind
+from torch.fx import GraphModule


 def external_constants_pass(
-    ep: ExportedProgram,
-) -> ExportedProgram:
+    gm: GraphModule,
+) -> PassResult:
     """
     Move all constants to external file.
     """
-    for module in ep.graph_module.modules():
+    mutated = False
+    for module in gm.modules():
         if not isinstance(module, torch.fx.GraphModule):
             continue

@@ -26,4 +29,46 @@ def external_constants_pass(
                 spec = node.meta.get("spec")
                 if isinstance(spec, TensorSpec) and spec.const:
                     node.meta["constant_tag"] = "_default_external_constant"
-    return ep
+                    mutated = True
+    return PassResult(gm, mutated)
+
+
+def _is_mutable_weight(node: torch.fx.Node, ep: ExportedProgram) -> bool:
+    grad_targets = [
+        spec.target
+        for spec in ep.graph_signature.output_specs
+        if spec.kind == OutputKind.GRADIENT_TO_PARAMETER
+    ]
+    return (
+        node.op == "placeholder"
+        and node.target in ep.graph_signature.inputs_to_parameters.keys()
+        and ep.graph_signature.inputs_to_parameters[node.target] in grad_targets
+    )
+
+
+def external_mutable_weights_pass(
+    gm: GraphModule,
+    ep: ExportedProgram,
+) -> PassResult:
+    """
+    Move all mutable weights to external file.
+    """
+    # Pass the gm and the ep separately: the gm is being mutated by a number of passes
+    # in to_executorch, so the gm inside the ep is lagging behind, but the ep's graph
+    # signature is still correct.
+    # This is really tech debt, and all of the passes should be refactored to just mutate the ep.
+    mutated = False
+    for module in gm.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+
+        for node in module.graph.nodes:
+            if node.op == "placeholder":
+                spec = node.meta.get("spec")
+                if (
+                    isinstance(spec, TensorSpec)
+                    and spec.const
+                    and _is_mutable_weight(node, ep)
+                ):
+                    node.meta["constant_tag"] = "_default_external_constant"
+                    mutated = True
+    return PassResult(gm, mutated)
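
With the new signatures, both passes consume the GraphModule being lowered and report mutation through a standard torch.fx PassResult; external_mutable_weights_pass additionally takes the ExportedProgram for its graph signature. A hedged sketch of calling the constants pass directly (normally to_executorch drives this; outside that flow the TensorSpec metadata on nodes is typically absent, so nothing gets tagged here):

import torch
from torch.export import export
from executorch.exir.passes.external_constants_pass import external_constants_pass

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)

ep = export(M(), (torch.randn(1, 2),), strict=True)
# The pass now returns a PassResult instead of an ExportedProgram.
result = external_constants_pass(ep.graph_module)
assert result.graph_module is ep.graph_module
print(result.modified)  # False here, since no specs were populated outside to_executorch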

exir/program/_program.py

Lines changed: 14 additions & 3 deletions
@@ -35,7 +35,10 @@
     MemoryFormatOpsPass,
     OpReplacePass,
 )
-from executorch.exir.passes.external_constants_pass import external_constants_pass
+from executorch.exir.passes.external_constants_pass import (
+    external_constants_pass,
+    external_mutable_weights_pass,
+)
 from executorch.exir.passes.insert_write_back_for_buffers_pass import (
     insert_write_back_for_buffers_pass,
 )

@@ -1395,6 +1398,14 @@ def to_executorch(
                # TODO(who?)
                p.update_placeholder_tensor_specs(program, new_gm)

+            # Extract constants if the config says to.
+            if config.external_constants:
+                new_gm_res = external_constants_pass(new_gm)
+                new_gm = new_gm_res.graph_module
+            elif config.external_mutable_weights:
+                new_gm_res = external_mutable_weights_pass(new_gm, program)
+                new_gm = new_gm_res.graph_module
+
            if isinstance(config.memory_planning_pass, dict):
                memory_planning_pass = config.memory_planning_pass.get(
                    name, ExecutorchBackendConfig().memory_planning_pass

@@ -1409,8 +1420,8 @@
            else:
                new_gm_res = memory_planning_pass(new_gm)  # pyre-ignore[29]

-            if config.external_constants:
-                new_gm_res = external_constants_pass(new_gm_res)
+            # WARNING: DO NOT ADD ANY MORE PASSES AFTER THE MEMORY PLANNING PASS.
+            # THERE ARE A LOT OF ASSUMPTIONS IN THE STACK THAT MEMORY PLANNING IS THE LAST PASS BEFORE THE EMITTER.
            assert new_gm_res is not None
            new_gm = new_gm_res.graph_module
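
The added dispatch is a plain if/elif, so at most one extraction pass runs per lowering, with external_constants taking precedence over external_mutable_weights; both now run before memory planning, preserving the invariant that memory planning is the last pass before the emitter. A simplified sketch of just the selection logic (not the real function):

def pick_extraction_pass(external_constants: bool, external_mutable_weights: bool) -> str:
    # Mirrors the if/elif added ahead of memory planning in to_executorch.
    if external_constants:
        return "external_constants_pass"
    elif external_mutable_weights:
        return "external_mutable_weights_pass"
    return "none"

assert pick_extraction_pass(True, True) == "external_constants_pass"
assert pick_extraction_pass(False, True) == "external_mutable_weights_pass"
assert pick_extraction_pass(False, False) == "none"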
