-
Notifications
You must be signed in to change notification settings - Fork 607
[ExecuTorch][BE] Split kv cache and SDPA for better code sharing #7413
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Summary: This enables us to do more easier module swap with model definitions from torchtune Test Plan: CI Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/7413
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New FailuresAs of commit b6a4eb5 with merge base f4e77c7 ( NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…for better code sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress [ghstack-poisoned]
…sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress [ghstack-poisoned]
Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress ghstack-source-id: 6289ce2 Pull Request resolved: #7413
@@ -212,6 +215,13 @@ def export(self) -> "LLMEdgeManager": | |||
|
|||
return self | |||
|
|||
def run_canonical_optimizations(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment on this function
@@ -47,20 +37,21 @@ def forward( | |||
seqlen, | |||
mask, | |||
): | |||
q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just thought about it again, and adding this transpose here and also before in the llama_transformer.py
so that we can share code for kv_cache.py (this is the reason right?) doesn't really make sense since we are using a custom export-friendly KV cache already anyways: https://github.com/pytorch/executorch/blob/main/extension/llm/modules/kv_cache.py#L13
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this is really standardizing what the SDPA api is. That is the input tensors q, k and v are in [bs, num heads, seq len, head dim]. This is the standard API. This actually applies to output as well, but that part i havent yet fixed.
So now imagine you did not use kv cache but just wanted to replace SDPA. This allows you to do that.
Regarding KV cache itself, the transpose_cache
argument is also relic from the time of using sdpa_with_kv_cache. We dont really need tranpose_cache arg at all. Removing that will bring it closer to tune's kv cache impl. In fact I would like us to not have to use own own kv cache.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I see you are dealing with the paired transposes later anyway. Also extra bit of info is that from my own benchmarking, these transposes here don't seem to be much overhead
@@ -212,6 +215,13 @@ def export(self) -> "LLMEdgeManager": | |||
|
|||
return self | |||
|
|||
def run_canonical_optimizations(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please also add checks to make sure self.pre_autograd_graph_module
is not None, basically this needs to be run after export()
.
…for better code sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress [ghstack-poisoned]
…sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress [ghstack-poisoned]
Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress ghstack-source-id: 05da5d0 Pull Request resolved: #7413
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good 👍🏻 pending comments + testing
@@ -47,20 +37,21 @@ def forward( | |||
seqlen, | |||
mask, | |||
): | |||
q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I see you are dealing with the paired transposes later anyway. Also extra bit of info is that from my own benchmarking, these transposes here don't seem to be much overhead
return dim_0, dim_1 | ||
|
||
|
||
class RemoveRedundantTransposes(ExportPass): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would name something else, since this is also undoing permutes as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah good point. I should rename it to permute as transpose is special case of permute as well
graph_module.graph.eliminate_dead_code() | ||
graph_module.recompile() | ||
|
||
return PassResult(graph_module, graph_changed) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we just make this a default pass and run this in to_edge always?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So for now I am gonna add it here in next PR. I need to keep this here just to not break any backend delegation stuff for now.
Update: synced offline and wanted to put the notes from the sync somewhere:
|
…for better code sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
…sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress ghstack-source-id: 14f5d76 Pull Request resolved: #7413
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
…for better code sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
…sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
…sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress ghstack-source-id: 1079c1d Pull Request resolved: #7413
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
…for better code sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
…sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress ghstack-source-id: abaea2c Pull Request resolved: #7413
…for better code sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
…sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress ghstack-source-id: a5601e0 Pull Request resolved: #7413
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
…for better code sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
…sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress ghstack-source-id: fbb04f1 Pull Request resolved: #7413
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
…KV cache update operator ## Context #7413 and #7412 split the `sdpa_with_kv_cache` operator into two separate operators, `update_cache` and `custom_sdpa` to decouple the cache update step from the actual SDPA computation. As a result, SDPA is no longer being delegated on Vulkan because of this interface change. To rectify this, Vulkan must also split `sdpa_with_kv_cache` into two operators. Note that during this diff the new operators are not partitioned yet because of complications caused by assertion ops in the graph. The next diff adds a pass to remove such assertion ops which allows the new operators to be partitioned. Differential Revision: [D68916952](https://our.internmc.facebook.com/intern/diff/D68916952/) [ghstack-poisoned]
…KV cache update operator ## Context #7413 and #7412 split the `sdpa_with_kv_cache` operator into two separate operators, `update_cache` and `custom_sdpa` to decouple the cache update step from the actual SDPA computation. As a result, SDPA is no longer being delegated on Vulkan because of this interface change. To rectify this, Vulkan must also split `sdpa_with_kv_cache` into two operators. Note that during this diff the new operators are not partitioned yet because of complications caused by assertion ops in the graph. The next diff adds a pass to remove such assertion ops which allows the new operators to be partitioned. Differential Revision: [D68916952](https://our.internmc.facebook.com/intern/diff/D68916952/) [ghstack-poisoned]
…KV cache update operator ## Context #7413 and #7412 split the `sdpa_with_kv_cache` operator into two separate operators, `update_cache` and `custom_sdpa` to decouple the cache update step from the actual SDPA computation. As a result, SDPA is no longer being delegated on Vulkan because of this interface change. To rectify this, Vulkan must also split `sdpa_with_kv_cache` into two operators. Note that during this diff the new operators are not partitioned yet because of complications caused by assertion ops in the graph. The next diff adds a pass to remove such assertion ops which allows the new operators to be partitioned. Differential Revision: [D68916952](https://our.internmc.facebook.com/intern/diff/D68916952/) [ghstack-poisoned]
…KV cache update operator ## Context #7413 and #7412 split the `sdpa_with_kv_cache` operator into two separate operators, `update_cache` and `custom_sdpa` to decouple the cache update step from the actual SDPA computation. As a result, SDPA is no longer being delegated on Vulkan because of this interface change. To rectify this, Vulkan must also split `sdpa_with_kv_cache` into two operators. Note that during this diff the new operators are not partitioned yet because of complications caused by assertion ops in the graph. The next diff adds a pass to remove such assertion ops which allows the new operators to be partitioned. Differential Revision: [D68919676](https://our.internmc.facebook.com/intern/diff/D68919676/) [ghstack-poisoned]
…KV cache update operator ## Context #7413 and #7412 split the `sdpa_with_kv_cache` operator into two separate operators, `update_cache` and `custom_sdpa` to decouple the cache update step from the actual SDPA computation. As a result, SDPA is no longer being delegated on Vulkan because of this interface change. To rectify this, Vulkan must also split `sdpa_with_kv_cache` into two operators. Note that during this diff the new operators are not partitioned yet because of complications caused by assertion ops in the graph. The next diff adds a pass to remove such assertion ops which allows the new operators to be partitioned. Differential Revision: [D68919676](https://our.internmc.facebook.com/intern/diff/D68919676/) [ghstack-poisoned]
…KV cache update operator ## Context #7413 and #7412 split the `sdpa_with_kv_cache` operator into two separate operators, `update_cache` and `custom_sdpa` to decouple the cache update step from the actual SDPA computation. As a result, SDPA is no longer being delegated on Vulkan because of this interface change. To rectify this, Vulkan must also split `sdpa_with_kv_cache` into two operators. Note that during this diff the new operators are not partitioned yet because of complications caused by assertion ops in the graph. The next diff adds a pass to remove such assertion ops which allows the new operators to be partitioned. Differential Revision: [D68919676](https://our.internmc.facebook.com/intern/diff/D68919676/) ghstack-source-id: 263930059 Pull Request resolved: #8068
…KV cache update operator + Add `RemoveAsserts` pass and apply it during LlaMa export **Note**: This diff is a combination of D68919676 (#8068) and D68919678 (no pull request). I decided to combine the two because of problems with `ghexport`, which was having some problems exporting the second diff, as well as the fact that both diffs are needed for `export_llama` to work so it makes more sense to just have a single diff. ## Context #7413 and #7412 split the `sdpa_with_kv_cache` operator into two separate operators, `update_cache` and `custom_sdpa` to decouple the cache update step from the actual SDPA computation. As a result, SDPA is no longer being delegated on Vulkan because of this interface change. To rectify this, Vulkan must also split `sdpa_with_kv_cache` into two operators. Note that during this diff the new operators are not partitioned yet because of complications caused by assertion ops in the graph. The next diff adds a pass to remove such assertion ops which allows the new operators to be partitioned. ## Context Recently, some assertion ops were added to the Llama source code. Unfortunately, this causes issues for the Vulkan delegate because runtime assertions are not yet supported in Vulkan and the assertion ops cause graph breaks due to not being supported. To prevent graph breaks when delegating to Vulkan, apply a pass to remove assertion ops during the llama export. Differential Revision: [D68922404](https://our.internmc.facebook.com/intern/diff/D68922404/) [ghstack-poisoned]
Differential Revision: D67914054 Pull Request resolved: pytorch#7413
Stack from ghstack (oldest at bottom):
Summary:
Why?
We have coupled SDPA with kv cache for a while. Initially this was done
as we implemented sdpa_with_kv_cache custom op to reduce multiple copy
overheads from kv cache update. (This could have been done by having
separate custom kv cache update and custom sdpa op. Recent changes
enabled this.)
As a result of SDPA module owning kv cache, we get a) non-composable
implementation and b) harder to reuse model definition and components
from repos like tune. Output of this is that we have multiple definition
of the same model, llama, lying around in ET, TorchChat and Tune. This
diff and subsequent ones will try to move in the direction where custom
kv cache and custom sdpa become decoupled and composable, making it more
module-swap friendly with tune's model definition.
How.
Earlier PRs decoupled kv cache update from sdpa. So now
both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted
tensors.
replaced by custom modules, but we will write graph pass to undo
those.
Test Plan:
Existing tests.
Make sure perf doesnt regress
Differential Revision: D67914054