Skip to content

Commit 61101c2

Browse files
committed
Add pass
1 parent 2a3fe8b commit 61101c2

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
from executorch.exir.pass_base import ExportPass
9+
10+
11+
class CachePosToInitializedMutableBufferPass(ExportPass):
12+
"""
13+
If the buffer has the name "cache_pos", such as in an kv_cache
14+
module with `self.register_buffer("cache_pos", torch.arange(10))`,
15+
mark it with a custom tag which later is used by the emitter to
16+
flag spec.const to True, which provides the mutable buffer with
17+
an initialized state.
18+
"""
19+
20+
def __init__(self) -> None:
21+
super().__init__()
22+
23+
def placeholder(self, name: str, arg, meta):
24+
if "cache_pos" in name:
25+
meta["et_init_buffer"] = True
26+
27+
return super().placeholder(name, arg, meta)

0 commit comments

Comments
 (0)