Skip to content

Refactor attention v2 #10623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 6, 2025
Merged

Conversation

lucylq
Copy link
Contributor

@lucylq lucylq commented May 1, 2025

Stack from ghstack (oldest at bottom):

Pull attention creation out of Transformer/TransformerBlock. Instead, pass the layers into Transformer.

The motivation is to customize linear layers in attention for LoRA (eg. make wq into a LoraLinear instead of a regular linear). In the next diff (D73517350), we pull wq,wk,wv,wo out of the attention and pass those in as well.

This allows us to customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

I think this modularizes our attention/transformer components, though also means that users have to do some more work to construct the attention layers and pass it to transformer.

It follows the torchtune structure more closely, eg. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Differential Revision: D73538697

Pull attention creation out of Transformer/TransformerBlock. Instead, pass the layers into Transformer.

The motivation is to customize linear layers in attention for LoRA (eg. make wq into a LoraLinear instead of a regular linear). In the next diff (D73517350), we pull wq,wk,wv,wo out of the attention and pass those in as well.

This allows us to customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

I think this modularizes our attention/transformer components, though also means that users have to do some more work to construct the attention layers and pass it to transformer.

It follows the torchtune structure more closely, eg. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)

[ghstack-poisoned]
@lucylq lucylq requested a review from jackzhxng as a code owner May 1, 2025 20:33
lucylq added a commit that referenced this pull request May 1, 2025
Pull attention creation out of Transformer/TransformerBlock. Instead, pass the layers into Transformer.

The motivation is to customize linear layers in attention for LoRA (eg. make wq into a LoraLinear instead of a regular linear). In the next diff (D73517350), we pull wq,wk,wv,wo out of the attention and pass those in as well.

This allows us to customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

I think this modularizes our attention/transformer components, though also means that users have to do some more work to construct the attention layers and pass it to transformer.

It follows the torchtune structure more closely, eg. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)

ghstack-source-id: 279884551
Pull Request resolved: #10623
Copy link

pytorch-bot bot commented May 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/10623

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 2 New Failures

As of commit a3fc78a with merge base 4129ebe (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 1, 2025
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D73538697

Copy link
Contributor

@iseeyuan iseeyuan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense to me to pass attention to the transformer, and unblock LoRA.

@@ -83,25 +84,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
def __init__(self, args: ModelArgs, attention: Attention):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add doc string on each argument, especially the attention? I think it makes sense to me that Attention type is required, so that the API of user-defined attention is compatible with our transformer.

Pull attention creation out of Transformer/TransformerBlock. Instead, pass the layers into Transformer.

The motivation is to customize linear layers in attention for LoRA (eg. make wq into a LoraLinear instead of a regular linear). In the next diff (D73517350), we pull wq,wk,wv,wo out of the attention and pass those in as well.

This allows us to customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

I think this modularizes our attention/transformer components, though also means that users have to do some more work to construct the attention layers and pass it to transformer.

It follows the torchtune structure more closely, eg. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)

[ghstack-poisoned]
lucylq added a commit that referenced this pull request May 3, 2025
Pull Request resolved: #10623

Pull attention creation out of Transformer/TransformerBlock. Instead, pass the layers into Transformer.

The motivation is to customize linear layers in attention for LoRA (eg. make wq into a LoraLinear instead of a regular linear). In the next diff (D73517350), we pull wq,wk,wv,wo out of the attention and pass those in as well.

This allows us to customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

I think this modularizes our attention/transformer components, though also means that users have to do some more work to construct the attention layers and pass it to transformer.

It follows the torchtune structure more closely, eg. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Previously here: D73474110

ghstack-source-id: 281805091
@exported-using-ghexport

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D73538697

Pull attention creation out of Transformer/TransformerBlock. Instead, pass the layers into Transformer.

The motivation is to customize linear layers in attention for LoRA (eg. make wq into a LoraLinear instead of a regular linear). In the next diff (D73517350), we pull wq,wk,wv,wo out of the attention and pass those in as well.

This allows us to customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

I think this modularizes our attention/transformer components, though also means that users have to do some more work to construct the attention layers and pass it to transformer.

It follows the torchtune structure more closely, eg. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)

[ghstack-poisoned]
lucylq added a commit that referenced this pull request May 5, 2025
Pull Request resolved: #10623

Pull attention creation out of Transformer/TransformerBlock. Instead, pass the layers into Transformer.

The motivation is to customize linear layers in attention for LoRA (eg. make wq into a LoraLinear instead of a regular linear). In the next diff (D73517350), we pull wq,wk,wv,wo out of the attention and pass those in as well.

This allows us to customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

I think this modularizes our attention/transformer components, though also means that users have to do some more work to construct the attention layers and pass it to transformer.

It follows the torchtune structure more closely, eg. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Previously here: D73474110

ghstack-source-id: 282064227
@exported-using-ghexport

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D73538697

@@ -117,7 +138,15 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:


class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if you are going to do this, might as well lift all of the major model components out as well, such as the embedding layer and rms norm, even though they are not customizable by model args at the moment

Copy link
Contributor Author

@lucylq lucylq May 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can, but would prefer to have it in a separate PR if it's something we want to do. Is there a use-case, or more to make Transformer more modular?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Up to you, no use case atm, just for modularity. Just feels a bit weird to me seeing layers and rope be the only lifted inputs for Transformer

@@ -212,3 +239,23 @@ def forward(
return logits, attn_options_update

return logits


def construct_transformer(model_args: ModelArgs) -> Transformer:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not @classmethod?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

discussed offline; construct_transformer is likely going to be more high-level; not quite at model-creation, but will contain eg. lora instantiation so may not make sense for it to be part of the transformer class itself.

Pull attention creation out of Transformer/TransformerBlock. Instead, pass the layers into Transformer.

The motivation is to customize linear layers in attention for LoRA (eg. make wq into a LoraLinear instead of a regular linear). In the next diff (D73517350), we pull wq,wk,wv,wo out of the attention and pass those in as well.

This allows us to customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

I think this modularizes our attention/transformer components, though also means that users have to do some more work to construct the attention layers and pass it to transformer.

It follows the torchtune structure more closely, eg. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)

[ghstack-poisoned]
lucylq added a commit that referenced this pull request May 5, 2025
Pull Request resolved: #10623

Pull attention creation out of Transformer/TransformerBlock. Instead, pass the layers into Transformer.

The motivation is to customize linear layers in attention for LoRA (eg. make wq into a LoraLinear instead of a regular linear). In the next diff (D73517350), we pull wq,wk,wv,wo out of the attention and pass those in as well.

This allows us to customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

I think this modularizes our attention/transformer components, though also means that users have to do some more work to construct the attention layers and pass it to transformer.

It follows the torchtune structure more closely, eg. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Previously here: D73474110

ghstack-source-id: 282118266
@exported-using-ghexport

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D73538697

@facebook-github-bot facebook-github-bot merged commit a4d5fb9 into gh/lucylq/74/base May 6, 2025
161 of 167 checks passed
@facebook-github-bot facebook-github-bot deleted the gh/lucylq/74/head branch May 6, 2025 03:59
kirklandsign pushed a commit that referenced this pull request May 6, 2025
Pull Request resolved: #10623

Pull attention creation out of Transformer/TransformerBlock. Instead, pass the layers into Transformer.

The motivation is to customize linear layers in attention for LoRA (eg. make wq into a LoraLinear instead of a regular linear). In the next diff (D73517350), we pull wq,wk,wv,wo out of the attention and pass those in as well.

This allows us to customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

I think this modularizes our attention/transformer components, though also means that users have to do some more work to construct the attention layers and pass it to transformer.

It follows the torchtune structure more closely, eg. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Previously here: D73474110

ghstack-source-id: 282118266
@exported-using-ghexport

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)
phaiting pushed a commit that referenced this pull request May 6, 2025
Pull Request resolved: #10623

Pull attention creation out of Transformer/TransformerBlock. Instead, pass the layers into Transformer.

The motivation is to customize linear layers in attention for LoRA (eg. make wq into a LoraLinear instead of a regular linear). In the next diff (D73517350), we pull wq,wk,wv,wo out of the attention and pass those in as well.

This allows us to customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

I think this modularizes our attention/transformer components, though also means that users have to do some more work to construct the attention layers and pass it to transformer.

It follows the torchtune structure more closely, eg. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Previously here: D73474110

ghstack-source-id: 282118266
@exported-using-ghexport

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)
jhelsby pushed a commit to jhelsby/executorch that referenced this pull request May 9, 2025
Pull Request resolved: pytorch#10623

Pull attention creation out of Transformer/TransformerBlock. Instead, pass the layers into Transformer.

The motivation is to customize linear layers in attention for LoRA (eg. make wq into a LoraLinear instead of a regular linear). In the next diff (D73517350), we pull wq,wk,wv,wo out of the attention and pass those in as well.

This allows us to customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

I think this modularizes our attention/transformer components, though also means that users have to do some more work to construct the attention layers and pass it to transformer.

It follows the torchtune structure more closely, eg. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Previously here: D73474110

ghstack-source-id: 282118266
@exported-using-ghexport

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported topic: not user facing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants