Add pass to extract mutable weights into a .ptd #7798

Merged 1 commit on Jan 21, 2025
4 changes: 4 additions & 0 deletions exir/capture/_config.py
@@ -92,3 +92,7 @@ class ExecutorchBackendConfig:
# If set to true, all constant tensors will be stored in a separate file,
# external to the PTE file.
external_constants: bool = False

# If set to true, all trainable weights will be stored in a separate file,
# external to the PTE file.
external_mutable_weights: bool = False
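
For context, a minimal usage sketch of the new flag (not part of this diff). It assumes the usual `executorch.exir` exports and an `EdgeProgramManager` named `edge_program` produced earlier by `to_edge`:

```python
from executorch.exir import ExecutorchBackendConfig

# With external_mutable_weights=True, trainable weights are tagged so they are
# emitted into an external (.ptd) data segment instead of the .pte file.
config = ExecutorchBackendConfig(external_mutable_weights=True)

# Assumption (not shown in this diff): `edge_program` is an EdgeProgramManager
# produced earlier by to_edge(...).
# executorch_program = edge_program.to_executorch(config=config)
```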
56 changes: 27 additions & 29 deletions exir/emit/_emitter.py
@@ -387,38 +387,36 @@ def _save_new_const_tensor(
# Update buffer_idx to point to the end of the list where we are adding the new buffer.
buffer = Buffer(storage=buffer_data)

# Tensor is mutable with initial state.
if allocation_info:
# Tensor is stored outside of the PTE file.
if (
spec.extra_tensor_info is not None
and spec.extra_tensor_info.fully_qualified_name is not None
and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
):
assert (
constant_tag is not None
), "Constant tag is not set for external tensor"
# TODO (#7633): Handle case where we have both mutable and non mutable weights that we want to put in the same external file.
# We will need to create 2 segments in that case, but it'll be a bit until we see this case. LLM finetuning will probably require this.

buffer_idx = len(self.program_state.external_constant_buffer)
self.program_state.external_constant_hash[hashed] = buffer_idx
self.program_state.external_constant_buffer.append(buffer_data)
if constant_tag not in self.program_state.external_constant_map:
self.program_state.external_constant_map[constant_tag] = {}
self.program_state.external_constant_map[constant_tag][
spec.extra_tensor_info.fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
] = buffer_idx
# Tensor is mutable with initial state. Place into mutable segment
elif allocation_info:
buffer_idx = len(self.program_state.mutable_buffer)
self.program_state.cached_spec_mutable_hash_values[hashed] = buffer_idx
self.program_state.mutable_buffer.append(buffer)

# Tensor is constant.
# Tensor is stored in the PTE file.
else:
# Tensor is stored outside of the PTE file.
if (
spec.extra_tensor_info is not None
and spec.extra_tensor_info.fully_qualified_name is not None
and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
):
assert (
constant_tag is not None
), "Constant tag is not set for external tensor"

buffer_idx = len(self.program_state.external_constant_buffer)
self.program_state.external_constant_hash[hashed] = buffer_idx
self.program_state.external_constant_buffer.append(buffer_data)
if constant_tag not in self.program_state.external_constant_map:
self.program_state.external_constant_map[constant_tag] = {}
self.program_state.external_constant_map[constant_tag][
spec.extra_tensor_info.fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
] = buffer_idx

# Tensor is stored in the PTE file.
else:
buffer_idx = len(self.program_state.constant_buffer)
self.program_state.cached_spec_hash_values[hashed] = buffer_idx
self.program_state.constant_buffer.append(buffer)
buffer_idx = len(self.program_state.constant_buffer)
self.program_state.cached_spec_hash_values[hashed] = buffer_idx
self.program_state.constant_buffer.append(buffer)

return buffer_idx

@@ -458,7 +456,7 @@ def _tensor_spec_to_evalue(

hashed = hashlib.sha256(buffer_data).hexdigest()

if allocation_info:
if allocation_info and spec.extra_tensor_info is None:
buffer_idx = self.program_state.cached_spec_mutable_hash_values.get(
hashed, -1
)
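
To make the new control flow in `_save_new_const_tensor` easier to follow, here is a self-contained toy sketch; plain lists and dicts stand in for the real `ProgramState`, so this is illustrative rather than the emitter's actual code. External placement is now checked first, then the mutable-with-initial-state case, then the ordinary PTE constant buffer. The second hunk above additionally keeps external tensors out of the mutable-hash cache so they always reach this routing.

```python
from typing import Dict, List, Optional


def route_tensor(
    is_external: bool,          # spec.extra_tensor_info marks the tensor EXTERNAL
    has_allocation_info: bool,  # mutable tensor with an initial state
    fqn: str,
    tag: Optional[str],
    data: bytes,
    external_buffer: List[bytes],
    external_map: Dict[str, Dict[str, int]],
    mutable_buffer: List[bytes],
    constant_buffer: List[bytes],
) -> int:
    if is_external:
        # External tensors (including mutable weights tagged by the new pass)
        # go into the .ptd-backed external constant buffer.
        assert tag is not None, "Constant tag is not set for external tensor"
        idx = len(external_buffer)
        external_buffer.append(data)
        external_map.setdefault(tag, {})[fqn] = idx
    elif has_allocation_info:
        # Mutable tensors with an initial state go into the mutable segment.
        idx = len(mutable_buffer)
        mutable_buffer.append(data)
    else:
        # Everything else stays in the .pte file's constant buffer.
        idx = len(constant_buffer)
        constant_buffer.append(data)
    return idx
```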
51 changes: 51 additions & 0 deletions exir/emit/test/test_emit.py
@@ -67,6 +67,7 @@
from torch import nn

from torch.export import Dim, export, export_for_training
from torch.export.experimental import _export_forward_backward


class WrapperModule(torch.nn.Module):
@@ -1733,3 +1734,53 @@ def forward(self, x):
self.assertEqual(
len(edge_program_manager.executorch_program.backend_delegate_data), 1
)

def test_constant_tagged_mutable_tensors(self) -> None:
class Net(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 2)

def forward(self, x):
return self.linear(x)

# On-device training requires the loss to be embedded in the model (and to be the first output).
# We wrap the original model here and add the loss calculation. This will be the model we export.
class TrainingNet(nn.Module):
def __init__(self, net):
super().__init__()
self.net = net
self.loss = nn.CrossEntropyLoss()

def forward(self, input, label):
pred = self.net(input)
return self.loss(pred, label), pred.detach().argmax(dim=1)

net = TrainingNet(Net())

# Captures the forward graph. The graph will look similar to the model definition now.
# Will move to export_for_training soon, which is the API planned for long-term support.
ep = export(
net, (torch.randn(1, 2), torch.ones(1, dtype=torch.int64)), strict=True
)
# Captures the backward graph. The exported_program now contains the joint forward and backward graph.
ep = _export_forward_backward(ep)
# Lower the graph to edge dialect.
ep = to_edge(ep)
# Lower the graph to executorch.
ep = ep.to_executorch(
config=ExecutorchBackendConfig(external_mutable_weights=True)
)

emitter_output = ep._emitter_output
# Check that constant_buffer is empty aside from the non-constant placeholder at index 0.
self.assertEqual(len(emitter_output.program.constant_buffer), 1)
# Check that constant weights are in the external constant buffer.
self.assertEqual(len(emitter_output.external_constant_buffer), 2)
# Setting external_mutable_weights=True saves all constants with an associated gradient under the key
# '_default_external_constant'.
external_map = emitter_output.external_constant_map[
"_default_external_constant"
]
self.assertEqual(external_map["net.linear.weight"], 0)
self.assertEqual(external_map["net.linear.bias"], 1)
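
Condensing the test above into an end-to-end sketch: a hedged helper (the name `export_with_external_weights` is an assumption for illustration) that exports a training model so its trainable weights are tagged for the external .ptd data, mirroring the calls exercised in `test_constant_tagged_mutable_tensors`:

```python
import torch
from executorch.exir import ExecutorchBackendConfig, to_edge
from torch.export import export
from torch.export.experimental import _export_forward_backward


def export_with_external_weights(model: torch.nn.Module, example_inputs: tuple):
    """Export a training model so its trainable weights land in the external buffer."""
    ep = export(model, example_inputs, strict=True)  # capture the forward graph
    ep = _export_forward_backward(ep)                # add the backward graph (joint graph)
    ep = to_edge(ep)                                 # lower to edge dialect
    return ep.to_executorch(                         # tag mutable weights as external
        config=ExecutorchBackendConfig(external_mutable_weights=True)
    )
```

The resulting program's `_emitter_output.external_constant_map["_default_external_constant"]` then maps fully qualified parameter names to indices in `external_constant_buffer`, as the assertions above check.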
55 changes: 50 additions & 5 deletions exir/passes/external_constants_pass.py
@@ -7,17 +7,20 @@
# pyre-strict

import torch
from executorch.exir.pass_base import PassResult
from executorch.exir.tensor import TensorSpec
from torch.export.exported_program import ExportedProgram
from torch.export.exported_program import ExportedProgram, OutputKind
from torch.fx import GraphModule


def external_constants_pass(
ep: ExportedProgram,
) -> ExportedProgram:
gm: GraphModule,
) -> PassResult:
"""
Move all constants to external file.
"""
for module in ep.graph_module.modules():
mutated = False
for module in gm.modules():
if not isinstance(module, torch.fx.GraphModule):
continue

@@ -26,4 +29,46 @@ def external_constants_pass(
spec = node.meta.get("spec")
if isinstance(spec, TensorSpec) and spec.const:
node.meta["constant_tag"] = "_default_external_constant"
return ep
mutated = True
return PassResult(gm, mutated)


def _is_mutable_weight(node: torch.fx.Node, ep: ExportedProgram) -> bool:
grad_targets = [
spec.target
for spec in ep.graph_signature.output_specs
if spec.kind == OutputKind.GRADIENT_TO_PARAMETER
]
return (
node.op == "placeholder"
and node.target in ep.graph_signature.inputs_to_parameters.keys()
and ep.graph_signature.inputs_to_parameters[node.target] in grad_targets
)


def external_mutable_weights_pass(
gm: GraphModule,
ep: ExportedProgram,
) -> PassResult:
"""
Move all mutable weights to external file.
"""
# Pass the gm and the ep separately: the gm is mutated by a number of passes in to_executorch,
# so the gm inside the ep lags behind, but the ep's graph signature is still correct.
# This is tech debt; all the passes should be refactored to mutate the ep directly.
mutated = False
for module in gm.modules():
if not isinstance(module, torch.fx.GraphModule):
continue

for node in module.graph.nodes:
if node.op == "placeholder":
spec = node.meta.get("spec")
if (
isinstance(spec, TensorSpec)
and spec.const
and _is_mutable_weight(node, ep)
):
node.meta["constant_tag"] = "_default_external_constant"
mutated = True
return PassResult(gm, mutated)
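
For illustration, a hedged, self-contained sketch of invoking the new pass directly and inspecting which placeholders were tagged. The toy module and the standalone invocation are assumptions; as noted in the comments, the pass relies on `TensorSpec` metadata that `to_executorch` normally attaches before running it, so this only demonstrates the call signature and the meaning of the tag.

```python
import torch
from torch import nn
from torch.export import export
from torch.export.experimental import _export_forward_backward
from executorch.exir.passes.external_constants_pass import external_mutable_weights_pass


class TinyTrainingNet(nn.Module):
    """Toy model with the loss as the first output, so gradients appear in the graph."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x, label):
        pred = self.linear(x)
        return self.loss(pred, label), pred.detach().argmax(dim=1)


ep = export(
    TinyTrainingNet(), (torch.randn(1, 2), torch.ones(1, dtype=torch.int64)), strict=True
)
ep = _export_forward_backward(ep)  # joint forward/backward graph

# Direct invocation of the pass. Caveat (assumption): node.meta["spec"] is normally
# populated inside to_executorch before this pass runs, so a standalone call like this
# may tag nothing; it only illustrates the signature and how the tag is checked.
result = external_mutable_weights_pass(ep.graph_module, ep)
tagged = [
    n.name
    for n in result.graph_module.graph.nodes
    if n.meta.get("constant_tag") == "_default_external_constant"
]
print("placeholders tagged for the external .ptd file:", tagged)
```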
17 changes: 14 additions & 3 deletions exir/program/_program.py
@@ -35,7 +35,10 @@
MemoryFormatOpsPass,
OpReplacePass,
)
from executorch.exir.passes.external_constants_pass import external_constants_pass
from executorch.exir.passes.external_constants_pass import (
external_constants_pass,
external_mutable_weights_pass,
)
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
insert_write_back_for_buffers_pass,
)
@@ -1395,6 +1398,14 @@ def to_executorch(
# TODO(who?)
p.update_placeholder_tensor_specs(program, new_gm)

# Extract constants if the config says to.
if config.external_constants:
new_gm_res = external_constants_pass(new_gm)
new_gm = new_gm_res.graph_module
elif config.external_mutable_weights:
new_gm_res = external_mutable_weights_pass(new_gm, program)
new_gm = new_gm_res.graph_module

if isinstance(config.memory_planning_pass, dict):
memory_planning_pass = config.memory_planning_pass.get(
name, ExecutorchBackendConfig().memory_planning_pass
@@ -1409,8 +1420,8 @@
else:
new_gm_res = memory_planning_pass(new_gm) # pyre-ignore[29]

if config.external_constants:
new_gm_res = external_constants_pass(new_gm_res)
# WARNING: DO NOT ADD ANY MORE PASSES AFTER MEMORY PLANNING PASS.
# THERE ARE A LOT OF ASSUMPTIONS IN THE STACK THAT MEMORY PLANNING IS THE LAST PASS BEFORE THE EMITTER.
assert new_gm_res is not None
new_gm = new_gm_res.graph_module
