|
9 | 9 | import typing
|
10 | 10 | import unittest
|
11 | 11 | from contextlib import contextmanager
|
| 12 | +from copy import deepcopy |
12 | 13 | from typing import List, Optional, Tuple
|
13 | 14 |
|
14 | 15 | import executorch.exir as exir
|
|
31 | 32 | from executorch.exir.error import InternalError
|
32 | 33 | from executorch.exir.passes import MemoryPlanningPass
|
33 | 34 | from executorch.exir.passes.constant_prop_pass import constant_prop_pass
|
| 35 | +from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass |
34 | 36 | from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
|
35 | 37 | from executorch.exir.print_program import pretty_print, print_program # noqa
|
36 | 38 | from executorch.exir.schema import (
|
|
56 | 58 | from executorch.extension.pybindings.portable_lib import (
|
57 | 59 | _load_for_executorch_from_buffer,
|
58 | 60 | )
|
| 61 | +from executorch.runtime import Runtime |
59 | 62 |
|
60 | 63 | from functorch.experimental import control_flow
|
61 | 64 | from torch import nn
|
@@ -243,6 +246,55 @@ def forward(self, x):
|
243 | 246 | )
|
244 | 247 | self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null)
|
245 | 248 |
|
| 249 | + def test_initialized_mutable_buffer(self): |
| 250 | + """Test that mutable buffers can hold meaningful initialized state.""" |
| 251 | + |
| 252 | + class TestModule(torch.nn.Module): |
| 253 | + def __init__(self): |
| 254 | + super().__init__() |
| 255 | + # Mutable buffer with non-empty initial state. |
| 256 | + self.register_buffer("cache_pos", torch.arange(0, 10)) |
| 257 | + |
| 258 | + def forward(self, x): |
| 259 | + self.cache_pos.add_(1) |
| 260 | + return self.cache_pos |
| 261 | + |
| 262 | + m = TestModule() |
| 263 | + example_inputs = (torch.ones(10),) |
| 264 | + ep = torch.export.export(m, example_inputs) |
| 265 | + edge = to_edge( |
| 266 | + ep, |
| 267 | + compile_config=EdgeCompileConfig( |
| 268 | + _check_ir_validity=False, |
| 269 | + ), |
| 270 | + ) |
| 271 | + |
| 272 | + # Save a copy of the edge program since to_executorch is |
| 273 | + # stateful to sombe degree. |
| 274 | + edge_copy = deepcopy(edge) |
| 275 | + et_config = ExecutorchBackendConfig( |
| 276 | + passes=[InitializedMutableBufferPass(["cache_pos"])], |
| 277 | + ) |
| 278 | + et_program_init_pass = edge.to_executorch(config=et_config) |
| 279 | + et_program_regular = edge_copy.to_executorch() |
| 280 | + |
| 281 | + runtime = Runtime.get() |
| 282 | + program_init_pass = runtime.load_program(et_program_init_pass.buffer) |
| 283 | + method_init_pass = program_init_pass.load_method("forward") |
| 284 | + |
| 285 | + program_regular = runtime.load_program(et_program_regular.buffer) |
| 286 | + method_regular = program_regular.load_method("forward") |
| 287 | + |
| 288 | + # Test that the mutable buffer is initialized. |
| 289 | + torch.allclose( |
| 290 | + method_init_pass.execute((example_inputs))[0], torch.arange(1, 11) |
| 291 | + ) |
| 292 | + # Test that the mutable buffer is uninitialized and starts with default zeros. |
| 293 | + torch.allclose( |
| 294 | + method_regular.execute((example_inputs))[0], |
| 295 | + torch.ones(10, dtype=torch.int64), |
| 296 | + ) |
| 297 | + |
246 | 298 | def test_int_list_input(self):
|
247 | 299 | class M(torch.nn.Module):
|
248 | 300 | def forward(self, x, y, z):
|
|
0 commit comments