Skip to content

Commit 93f99ad

Browse files
committed
Tests
1 parent a2b7ee3 commit 93f99ad

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

exir/emit/_emitter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1657,8 +1657,7 @@ def placeholder(
16571657
spec.storage = real_tensor.untyped_storage()
16581658

16591659
# User inputs and mutable buffers are not constants, other buffers or parameters are.
1660-
if initialize_buffer:
1661-
assert is_mutable_buffer
1660+
if initialize_buffer and is_mutable_buffer:
16621661
spec.const = True
16631662
else:
16641663
spec.const = not (is_user_input or is_mutable_buffer)

exir/emit/test/test_emit.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import typing
1010
import unittest
1111
from contextlib import contextmanager
12+
from copy import deepcopy
1213
from typing import List, Optional, Tuple
1314

1415
import executorch.exir as exir
@@ -31,6 +32,7 @@
3132
from executorch.exir.error import InternalError
3233
from executorch.exir.passes import MemoryPlanningPass
3334
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
35+
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
3436
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
3537
from executorch.exir.print_program import pretty_print, print_program # noqa
3638
from executorch.exir.schema import (
@@ -56,6 +58,7 @@
5658
from executorch.extension.pybindings.portable_lib import (
5759
_load_for_executorch_from_buffer,
5860
)
61+
from executorch.runtime import Runtime
5962

6063
from functorch.experimental import control_flow
6164
from torch import nn
@@ -243,6 +246,55 @@ def forward(self, x):
243246
)
244247
self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null)
245248

249+
def test_initialized_mutable_buffer(self):
250+
"""Test that mutable buffers can hold meaningful initialized state."""
251+
252+
class TestModule(torch.nn.Module):
253+
def __init__(self):
254+
super().__init__()
255+
# Mutable buffer with non-empty initial state.
256+
self.register_buffer("cache_pos", torch.arange(0, 10))
257+
258+
def forward(self, x):
259+
self.cache_pos.add_(1)
260+
return self.cache_pos
261+
262+
m = TestModule()
263+
example_inputs = (torch.ones(10),)
264+
ep = torch.export.export(m, example_inputs)
265+
edge = to_edge(
266+
ep,
267+
compile_config=EdgeCompileConfig(
268+
_check_ir_validity=False,
269+
),
270+
)
271+
272+
# Save a copy of the edge program since to_executorch is
273+
# stateful to sombe degree.
274+
edge_copy = deepcopy(edge)
275+
et_config = ExecutorchBackendConfig(
276+
passes=[InitializedMutableBufferPass(["cache_pos"])],
277+
)
278+
et_program_init_pass = edge.to_executorch(config=et_config)
279+
et_program_regular = edge_copy.to_executorch()
280+
281+
runtime = Runtime.get()
282+
program_init_pass = runtime.load_program(et_program_init_pass.buffer)
283+
method_init_pass = program_init_pass.load_method("forward")
284+
285+
program_regular = runtime.load_program(et_program_regular.buffer)
286+
method_regular = program_regular.load_method("forward")
287+
288+
# Test that the mutable buffer is initialized.
289+
torch.allclose(
290+
method_init_pass.execute((example_inputs))[0], torch.arange(1, 11)
291+
)
292+
# Test that the mutable buffer is uninitialized and starts with default zeros.
293+
torch.allclose(
294+
method_regular.execute((example_inputs))[0],
295+
torch.ones(10, dtype=torch.int64),
296+
)
297+
246298
def test_int_list_input(self):
247299
class M(torch.nn.Module):
248300
def forward(self, x, y, z):

0 commit comments

Comments
 (0)