Commit 147b4b2

JacobSzwejbka authored and facebook-github-bot committed
Add pass to extract mutable weights into a .ptd
Summary:
Cleaned up the existing pass and fixed a typing error (EP -> PassResult). Added another option to the backend config to extract only mutable weights (training workflows will do this). Fixed the ordering of ET passes and added a warning not to add anything after memory planning (this pass was actually fine, but in general we like the invariant that memory planning runs last). Fixed the emitter to prioritize making a tensor external over mutable. Down the line we will need to support intermixing mutable and non-mutable weights in the same .ptd (the stakes are memory regressions, not correctness), but no one needs that today, so it is deferred.

Reviewed By: lucylq

Differential Revision: D68121580
1 parent 948fba6 commit 147b4b2
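
For orientation before the per-file diffs, here is a minimal end-to-end sketch of the workflow this commit enables, modeled on the test added below. It assumes the executorch.exir re-exports of ExecutorchBackendConfig and to_edge; treat it as an illustration, not documentation.

import torch
from torch import nn
from torch.export import export
from torch.export.experimental import _export_forward_backward
from executorch.exir import ExecutorchBackendConfig, to_edge


class TrainingNet(nn.Module):
    """Tiny net with the loss folded in, as on-device training expects."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x, label):
        pred = self.linear(x)
        return self.loss(pred, label), pred.detach().argmax(dim=1)


ep = export(TrainingNet(), (torch.randn(1, 2), torch.ones(1, dtype=torch.int64)), strict=True)
ep = _export_forward_backward(ep)  # joint forward + backward graph
executorch_prog = to_edge(ep).to_executorch(
    # New in this commit: route only the trainable weights to a separate .ptd payload.
    config=ExecutorchBackendConfig(external_mutable_weights=True)
)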

File tree

exir/capture/_config.py
exir/emit/_emitter.py
exir/emit/test/test_emit.py
exir/passes/external_constants_pass.py
exir/program/_program.py

5 files changed: 132 additions & 37 deletions

exir/capture/_config.py

Lines changed: 4 additions & 0 deletions

@@ -92,3 +92,7 @@ class ExecutorchBackendConfig:
     # If set to true, all constant tensors will be stored in a separate file,
     # external to the PTE file.
     external_constants: bool = False
+
+    # If set to true, all trainable weights will be stored in a separate file,
+    # external to the PTE file.
+    external_mutable_weights: bool = False
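
A quick sketch of how the two flags differ in intent (import path as used by the tests; an illustration, not documentation): external_constants externalizes every constant tensor, while the new external_mutable_weights externalizes only the trainable weights.

from executorch.exir import ExecutorchBackendConfig

# All constant tensors go to the external .ptd file (inference weight-sharing use case).
all_external = ExecutorchBackendConfig(external_constants=True)

# Only weights that receive gradients go to the external .ptd file
# (the training/finetuning workflow this commit targets).
trainable_external = ExecutorchBackendConfig(external_mutable_weights=True)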

exir/emit/_emitter.py

Lines changed: 27 additions & 29 deletions

@@ -387,38 +387,36 @@ def _save_new_const_tensor(
         # Update buffer_idx to point to the end of the list where we are adding the new buffer.
         buffer = Buffer(storage=buffer_data)
 
-        # Tensor is mutable with initial state.
-        if allocation_info:
+        # Tensor is stored outside of the PTE file.
+        if (
+            spec.extra_tensor_info is not None
+            and spec.extra_tensor_info.fully_qualified_name is not None
+            and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
+        ):
+            assert (
+                constant_tag is not None
+            ), "Constant tag is not set for external tensor"
+            # TODO (#7633): Handle the case where we have both mutable and non-mutable weights that we want to put in the same external file.
+            # We will need to create 2 segments in that case, but it'll be a while until we see this case. LLM finetuning will probably require this.
+
+            buffer_idx = len(self.program_state.external_constant_buffer)
+            self.program_state.external_constant_hash[hashed] = buffer_idx
+            self.program_state.external_constant_buffer.append(buffer_data)
+            if constant_tag not in self.program_state.external_constant_map:
+                self.program_state.external_constant_map[constant_tag] = {}
+            self.program_state.external_constant_map[constant_tag][
+                spec.extra_tensor_info.fully_qualified_name  # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
+            ] = buffer_idx
+        # Tensor is mutable with initial state. Place it into the mutable segment.
+        elif allocation_info:
             buffer_idx = len(self.program_state.mutable_buffer)
             self.program_state.cached_spec_mutable_hash_values[hashed] = buffer_idx
             self.program_state.mutable_buffer.append(buffer)
-
-        # Tensor is constant.
+        # Tensor is stored in the PTE file.
         else:
-            # Tensor is stored outside of the PTE file.
-            if (
-                spec.extra_tensor_info is not None
-                and spec.extra_tensor_info.fully_qualified_name is not None
-                and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
-            ):
-                assert (
-                    constant_tag is not None
-                ), "Constant tag is not set for external tensor"
-
-                buffer_idx = len(self.program_state.external_constant_buffer)
-                self.program_state.external_constant_hash[hashed] = buffer_idx
-                self.program_state.external_constant_buffer.append(buffer_data)
-                if constant_tag not in self.program_state.external_constant_map:
-                    self.program_state.external_constant_map[constant_tag] = {}
-                self.program_state.external_constant_map[constant_tag][
-                    spec.extra_tensor_info.fully_qualified_name  # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
-                ] = buffer_idx
-
-            # Tensor is stored in the PTE file.
-            else:
-                buffer_idx = len(self.program_state.constant_buffer)
-                self.program_state.cached_spec_hash_values[hashed] = buffer_idx
-                self.program_state.constant_buffer.append(buffer)
+            buffer_idx = len(self.program_state.constant_buffer)
+            self.program_state.cached_spec_hash_values[hashed] = buffer_idx
+            self.program_state.constant_buffer.append(buffer)
 
         return buffer_idx
 
@@ -458,7 +456,7 @@ def _tensor_spec_to_evalue(
 
         hashed = hashlib.sha256(buffer_data).hexdigest()
 
-        if allocation_info:
+        if allocation_info and spec.extra_tensor_info is None:
             buffer_idx = self.program_state.cached_spec_mutable_hash_values.get(
                 hashed, -1
             )
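
To make the reordering concrete, here is a standalone sketch (our own names, not the emitter's internals) of the new placement priority in _save_new_const_tensor: external beats mutable, which beats the in-PTE constant segment, so a mutable weight tagged external now lands in the .ptd file.

from enum import Enum, auto

class Placement(Enum):
    EXTERNAL = auto()  # separate .ptd file
    MUTABLE = auto()   # mutable segment inside the PTE
    CONSTANT = auto()  # constant segment inside the PTE

def place_tensor(is_external: bool, has_allocation_info: bool) -> Placement:
    if is_external:          # checked first as of this commit
        return Placement.EXTERNAL
    if has_allocation_info:  # mutable with initial state
        return Placement.MUTABLE
    return Placement.CONSTANT

# A trainable weight is both mutable and tagged external; it now goes external
# (previously the mutable branch won and it stayed inside the PTE).
assert place_tensor(is_external=True, has_allocation_info=True) is Placement.EXTERNAL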

exir/emit/test/test_emit.py

Lines changed: 49 additions & 0 deletions

@@ -67,6 +67,9 @@
 from torch import nn
 
 from torch.export import Dim, export, export_for_training
+from torch.export import Dim, export
+from torch.export.experimental import _export_forward_backward
+from torch.nn import functional as F
 
 
 class WrapperModule(torch.nn.Module):
@@ -1733,3 +1736,49 @@ def forward(self, x):
         self.assertEqual(
             len(edge_program_manager.executorch_program.backend_delegate_data), 1
         )
+
+    def test_constant_tagged_mutable_tensors(self) -> None:
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        # On-device training requires the loss to be embedded in the model (and be the first output).
+        # We wrap the original model here and add the loss calculation. This will be the model we export.
+        class TrainingNet(nn.Module):
+            def __init__(self, net):
+                super().__init__()
+                self.net = net
+                self.loss = nn.CrossEntropyLoss()
+
+            def forward(self, input, label):
+                pred = self.net(input)
+                return self.loss(pred, label), pred.detach().argmax(dim=1)
+
+        net = TrainingNet(Net())
+
+        # Captures the forward graph. The graph will look similar to the model definition now.
+        # Will move to export_for_training soon, which is the API planned to be supported in the long term.
+        ep = export(net, (torch.randn(1, 2), torch.ones(1, dtype=torch.int64)), strict=True)
+        # Captures the backward graph. The exported_program now contains the joint forward and backward graph.
+        ep = _export_forward_backward(ep)
+        # Lower the graph to edge dialect.
+        ep = to_edge(ep)
+        # Lower the graph to executorch.
+        ep = ep.to_executorch(config=ExecutorchBackendConfig(external_mutable_weights=True))
+
+        emitter_output = ep._emitter_output
+        # Check that constant_buffer is empty besides the non-constant placeholder 0.
+        self.assertEqual(len(emitter_output.program.constant_buffer), 1)
+        # Check that constant weights are in the external constant buffer.
+        self.assertEqual(len(emitter_output.external_constant_buffer), 2)
+        # Setting external_mutable_weights=True saves all constants with an associated gradient to the key
+        # '_default_external_constant'.
+        external_map = emitter_output.external_constant_map[
+            "_default_external_constant"
+        ]
+        self.assertEqual(external_map["net.linear.weight"], 0)
+        self.assertEqual(external_map["net.linear.bias"], 1)

exir/passes/external_constants_pass.py

Lines changed: 41 additions & 5 deletions

@@ -7,17 +7,20 @@
 # pyre-strict
 
 import torch
+from executorch.exir.pass_base import PassResult
 from executorch.exir.tensor import TensorSpec
-from torch.export.exported_program import ExportedProgram
+from torch.export.exported_program import ExportedProgram, OutputKind
+from torch.fx import GraphModule
 
 
 def external_constants_pass(
-    ep: ExportedProgram,
-) -> ExportedProgram:
+    gm: GraphModule,
+) -> PassResult:
     """
     Move all constants to external file.
     """
-    for module in ep.graph_module.modules():
+    mutated = False
+    for module in gm.modules():
         if not isinstance(module, torch.fx.GraphModule):
             continue
 
@@ -26,4 +29,37 @@
             spec = node.meta.get("spec")
             if isinstance(spec, TensorSpec) and spec.const:
                 node.meta["constant_tag"] = "_default_external_constant"
-    return ep
+                mutated = True
+    return PassResult(gm, mutated)
+
+
+def _is_mutable_weight(node: torch.fx.Node, ep: ExportedProgram) -> bool:
+    grad_targets = [
+        spec.target
+        for spec in ep.graph_signature.output_specs
+        if spec.kind == OutputKind.GRADIENT_TO_PARAMETER
+    ]
+    return node.op == "placeholder" \
+        and node.target in ep.graph_signature.inputs_to_parameters.keys() \
+        and ep.graph_signature.inputs_to_parameters[node.target] in grad_targets
+
+
+def external_mutable_weights_pass(
+    gm: GraphModule, ep: ExportedProgram,
+) -> PassResult:
+    """
+    Move all mutable weights to external file.
+    """
+    # Pass the gm and the ep separately, as the gm is being mutated by a bunch of passes in to_executorch,
+    # so the gm in the ep is lagging behind; the graph signature, however, is still correct.
+    # This is really tech debt, and all the passes should be refactored to just mutate the ep.
+    mutated = False
+    for module in gm.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+
+        for node in module.graph.nodes:
+            if node.op == "placeholder":
+                spec = node.meta.get("spec")
+                if isinstance(spec, TensorSpec) and spec.const and _is_mutable_weight(node, ep):
+                    node.meta["constant_tag"] = "_default_external_constant"
+                    mutated = True
+    return PassResult(gm, mutated)
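
A rough illustration (not part of the commit) of the graph-signature fields _is_mutable_weight consults: after _export_forward_backward, each parameter that receives a gradient appears as a GRADIENT_TO_PARAMETER output spec, and placeholder nodes map back to parameter FQNs via inputs_to_parameters. Assuming those torch.export APIs, this prints which placeholders the pass would tag:

import torch
from torch import nn
from torch.export import export
from torch.export.experimental import _export_forward_backward
from torch.export.exported_program import OutputKind

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.loss = nn.MSELoss()

    def forward(self, x, target):
        return self.loss(self.linear(x), target)

ep = _export_forward_backward(export(M(), (torch.randn(1, 2), torch.randn(1, 2)), strict=True))

# Parameter FQNs that receive gradients in the joint graph -- exactly the
# tensors external_mutable_weights_pass tags for the external file.
grad_targets = {
    spec.target
    for spec in ep.graph_signature.output_specs
    if spec.kind == OutputKind.GRADIENT_TO_PARAMETER
}
for node in ep.graph_module.graph.nodes:
    if node.op == "placeholder":
        fqn = ep.graph_signature.inputs_to_parameters.get(node.target)
        if fqn in grad_targets:
            print("mutable weight:", fqn)  # e.g. linear.weight, linear.bias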

exir/program/_program.py

Lines changed: 11 additions & 3 deletions

@@ -35,7 +35,7 @@
     MemoryFormatOpsPass,
     OpReplacePass,
 )
-from executorch.exir.passes.external_constants_pass import external_constants_pass
+from executorch.exir.passes.external_constants_pass import external_constants_pass, external_mutable_weights_pass
 from executorch.exir.passes.insert_write_back_for_buffers_pass import (
     insert_write_back_for_buffers_pass,
 )
@@ -1394,6 +1394,14 @@ def to_executorch(
             # in the ExportedProgram
             # TODO(who?)
             p.update_placeholder_tensor_specs(program, new_gm)
+
+            # Extract constants if the config says to.
+            if config.external_constants:
+                new_gm_res = external_constants_pass(new_gm)
+                new_gm = new_gm_res.graph_module
+            elif config.external_mutable_weights:
+                new_gm_res = external_mutable_weights_pass(new_gm, program)
+                new_gm = new_gm_res.graph_module
 
             if isinstance(config.memory_planning_pass, dict):
                 memory_planning_pass = config.memory_planning_pass.get(
@@ -1409,8 +1417,8 @@
             else:
                 new_gm_res = memory_planning_pass(new_gm)  # pyre-ignore[29]
 
-            if config.external_constants:
-                new_gm_res = external_constants_pass(new_gm_res)
+            # WARNING: DO NOT ADD ANY MORE PASSES AFTER MEMORY PLANNING PASS.
+            # THERE ARE A LOT OF ASSUMPTIONS IN THE STACK THAT MEMORY PLANNING IS THE LAST PASS BEFORE THE EMITTER.
             assert new_gm_res is not None
             new_gm = new_gm_res.graph_module
