-
Notifications
You must be signed in to change notification settings - Fork 608
Fix executorch kv cache incompatibility with to_executorch lowering #7279
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
aac90a0
Add tests that localize the prefill issue to the kv cache
jackzhxng 917fb0d
Fixes test but not model
jackzhxng 46ea733
Updated pass
jackzhxng 5db136c
Fix segmentation fault
jackzhxng 9cdfb43
Lint
jackzhxng 9e68531
Only add pass when vision model
jackzhxng 925409d
Add comments
jackzhxng 2a3fe8b
Remove import
jackzhxng 61101c2
Add pass
jackzhxng 4ee95d3
PR review
jackzhxng e297c9b
Fix test
jackzhxng 8145cda
Last changes
jackzhxng 73591f1
Merge branch 'main' into jz/fix-prefill
jackzhxng a2b7ee3
Update attention test
jackzhxng 93f99ad
Tests
jackzhxng 69e36fb
Dave pr comment
jackzhxng 5c53856
Merge branch 'main' into jz/fix-prefill
jackzhxng 9d84a42
Merge branch 'main' into jz/fix-prefill
jackzhxng 6fe376d
Jacob pr review
jackzhxng File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
dbort marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
from typing import List | ||
|
||
from executorch.exir.pass_base import ExportPass | ||
|
||
|
||
class InitializedMutableBufferPass(ExportPass): | ||
""" | ||
If a buffer has a name that within a specified list, set meta["et_init_buffer"] | ||
to True, which provides the mutable buffer with an initialized state. | ||
As an example, a module with `self.register_buffer("cache_pos", torch.arange(10))` | ||
when patterns = ["cache_pos"] would have its initial state set instead of being | ||
left uninitialized by default. | ||
""" | ||
|
||
def __init__(self, patterns: List[str]) -> None: | ||
super().__init__() | ||
self.patterns = patterns | ||
|
||
def placeholder(self, name: str, arg, meta): | ||
for pattern in self.patterns: | ||
if pattern in name: | ||
meta["et_init_buffer"] = True | ||
|
||
jackzhxng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return super().placeholder(name, arg, meta) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be
zeros
, based on the comment? If not, please update the comment to clarify why this isones
. And if it should bezeros
, did this test fail?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, it's because in the forward of the model we do
self.cache_pos += 1
, I'll specify this