Commit 4903f0a

Keep LiftedConstant in .pte (#9202)
Summary: See #8809 for context. Lifted constants should not be moved to the external data file, as they are semantically closer to code than to data.

Differential Revision: D71064053
1 parent 16d3db2 · commit 4903f0a
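For readers unfamiliar with the term, here is a short sketch of what "lifted" means in this context (my own illustration, not part of this commit): when torch.export traces a module, tensor literals created inside forward() are hoisted ("lifted") into graph inputs alongside parameters and buffers, and the graph signature records them as constant-tensor inputs. Exact enum names may vary across torch versions.

import torch
from torch import nn
from torch.export import export


class M(nn.Module):
    def forward(self, x):
        # This literal is neither a parameter nor a buffer; export lifts it
        # into a placeholder input tracked by the graph signature.
        return x * torch.tensor([[4.0, 3.0], [1.0, 2.0], [5.0, 6.0]])


ep = export(M(), (torch.ones(3, 2),), strict=True)
for spec in ep.graph_signature.input_specs:
    # Expect one CONSTANT_TENSOR entry (the lifted literal) and one USER_INPUT.
    print(spec.kind, spec.target)

Because such constants originate in the program text rather than in trained weights, this commit keeps them inside the .pte file instead of moving them to the external constant file.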

File tree

2 files changed: +31 −6 lines changed

exir/emit/test/test_emit.py

Lines changed: 27 additions & 4 deletions
@@ -1534,26 +1534,49 @@ def forward(self, x):
         self.assertEqual(len(program.constant_buffer[1].storage), 8)

     def test_emit_lifted_tensor_constant(self) -> None:
-        class LiftedConstants(nn.Module):
+        class LiftedTensorConstants(nn.Module):
             def __init__(self):
                 super().__init__()

             def forward(self, x):
                 x = x * torch.tensor([[4, 3], [1, 2], [5, 6]], dtype=torch.float)
                 return x

-        model = LiftedConstants()
+        model = LiftedTensorConstants()
+        # Specify that we want to move non-lifted constants to external file
+        et_cfg = ExecutorchBackendConfig(external_constants=True)
+        program = to_edge(
+            export(model, (torch.ones(3, 2),), strict=True)
+        ).to_executorch(et_cfg)
+        program = program._emitter_output.program
+        exec_plan = program.execution_plan[0]
+        # There should only be 1 input to this model.
+        self.assertEqual(len(exec_plan.inputs), 1)
+        self.assertEqual(len(program.constant_buffer), 2)
+        self.assertEqual(len(program.constant_buffer[1].storage), 24)

+    def test_emit_lifted_constant(self) -> None:
+        class LiftedConstants(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                x = x + 1
+                return x
+
+        model = LiftedConstants()
+        # Specify that we want to move non-lifted constants to external file
+        et_cfg = ExecutorchBackendConfig(external_constants=True)
         program = to_edge(
             export(model, (torch.ones(3, 2),), strict=True)
-        ).to_executorch()
+        ).to_executorch(et_cfg)

         program = program._emitter_output.program
         exec_plan = program.execution_plan[0]
         # There should only be 1 input to this model.
         self.assertEqual(len(exec_plan.inputs), 1)
         self.assertEqual(len(program.constant_buffer), 2)
-        self.assertEqual(len(program.constant_buffer[1].storage), 24)
+        self.assertEqual(len(program.constant_buffer[1].storage), 8)

     def test_mutable_buffers(self) -> None:
         def count_copies(gm: torch.fx.GraphModule) -> int:
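
As a quick aside (my own arithmetic, not part of the diff), the storage sizes asserted above follow from element counts and dtype widths: the lifted 3x2 float32 literal takes 6 * 4 = 24 bytes, and the constant lifted from `x + 1` is a single 8-byte element (int64 or float64, depending on how the scalar is materialized in a given version).

import torch

# 3x2 float32 tensor lifted in test_emit_lifted_tensor_constant
t = torch.tensor([[4, 3], [1, 2], [5, 6]], dtype=torch.float)
print(t.numel() * t.element_size())  # 6 * 4 = 24 bytes

# The scalar 1 from `x + 1`, assuming it is materialized as a 0-d int64 tensor
s = torch.tensor(1)
print(s.numel() * s.element_size())  # 1 * 8 = 8 bytes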

exir/passes/external_constants_pass.py

Lines changed: 4 additions & 2 deletions
@@ -17,15 +17,17 @@ def external_constants_pass(
     gm: GraphModule,
 ) -> PassResult:
     """
-    Move all constants to external file.
+    Move all non-lifted constants to external file.
+    NOTE: Lifted constants are not moved as they are closer
+    to code than data.
     """
     mutated = False
     for module in gm.modules():
         if not isinstance(module, torch.fx.GraphModule):
             continue

         for node in module.graph.nodes:
-            if node.op == "placeholder":
+            if (node.op == "placeholder") and ("_lifted_tensor" not in node.name):
                 spec = node.meta.get("spec")
                 if isinstance(spec, TensorSpec) and spec.const:
                     node.meta["constant_tag"] = "_default_external_constant"
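
A small helper (hypothetical, not part of this commit) can make the effect of the new guard visible: after lowering, lifted constants typically surface as placeholder nodes whose names contain "_lifted_tensor", which is exactly the substring the condition above checks, so printing placeholder names shows which constants stay in the .pte and which remain candidates for the external file.

from torch.fx import GraphModule


def report_placeholders(gm: GraphModule) -> None:
    """Print each placeholder and whether the pass above would now skip it."""
    for node in gm.graph.nodes:
        if node.op != "placeholder":
            continue
        tag = (
            "kept in .pte (lifted)"
            if "_lifted_tensor" in node.name
            else "may be tagged external (if its spec is a constant)"
        )
        print(f"{node.name}: {tag}")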
