Refactor attention v2 #10623
@@ -13,6 +13,7 @@
 import torch.nn.functional as F

 from executorch.examples.models.llama.attention import (
+    Attention,
     ATTENTION_REGISTRY,
     ForwardOptions,
 )
@@ -83,26 +84,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
+        """
+        Transformer block with support for pre-norm and post-norm.
+        Args:
+            args (ModelArgs): model configuration parameters.
+            attention (Attention): attention object to use in the transformer
+                block. See `attention.py` for types of attention. Make sure
+                the attention type is registered in the ATTENTION_REGISTRY.
+        """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+        self.attention = attention
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

+    @classmethod
+    def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
+        """
+        Create a TransformerBlock with the legacy constructor.
+        Args:
+            layer_id (int): the index of the layer.
+            args (ModelArgs): model configuration parameters.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
+        if args.attention_type not in ATTENTION_REGISTRY:
+            raise ValueError(
+                f"Unknown attention type: {args.attention_type}. "
+                f"Available: {list(ATTENTION_REGISTRY.keys())}"
+            )
+        cls = ATTENTION_REGISTRY[args.attention_type]
+        attention = cls(args, layer_id, rope)
+        return TransformerBlock(args, attention)
+
     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
         h, attn_options_update = self.attention.forward(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
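For orientation, here is a minimal sketch of the two construction paths this hunk introduces. It is not part of the diff: it is written as if inside the modified module (so `ModelArgs`, `Rope`, `TransformerBlock`, and `ATTENTION_REGISTRY` are in scope), it assumes `ModelArgs` can be built with defaults, and the `layer_id` of 0 is an arbitrary example value.

    # Sketch only: real configs would set dim, n_heads, attention_type, etc.
    args = ModelArgs()
    rope = Rope(args)

    # New path: resolve and construct the attention yourself, then inject it.
    attn_cls = ATTENTION_REGISTRY[args.attention_type]
    attention = attn_cls(args, 0, rope)  # (args, layer_id, rope), as used in this diff
    block = TransformerBlock(args, attention)

    # Legacy path: from_type resolves the attention type from the registry.
    block_legacy = TransformerBlock.from_type(0, args, rope)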
@@ -117,7 +138,15 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x:


 class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
+    def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
+        """
+        Transformer model.
+        Args:
+            params (ModelArgs): model configuration parameters.
+            layers (nn.ModuleList): list of transformer blocks - see the
+                `TransformerBlock` type above.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size

Review thread on the new `Transformer.__init__(params, layers, rope)` signature:
- If you are going to do this, you might as well lift all of the major model components out as well, such as the embedding layer and RMS norm, even though they are not customizable by model args at the moment.
- I think we can, but I would prefer to have it in a separate PR if it's something we want to do. Is there a use case, or is it more to make Transformer more modular?
- Up to you, no use case at the moment, just for modularity. It just feels a bit weird to me seeing layers and rope be the only lifted inputs for Transformer.
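To make the thread above concrete, a fully lifted constructor along the lines the reviewer floats might look roughly like the sketch below. This is purely an illustration of the suggestion, not code from this PR; it is written as if inside the same module so the referenced types are in scope, and attribute names such as `tok_embeddings` are assumptions.

    # Hypothetical alternative, not in this PR: all major components injected.
    class FullyLiftedTransformer(nn.Module):
        def __init__(
            self,
            params: ModelArgs,
            layers: nn.ModuleList,
            rope: Rope,
            tok_embeddings: nn.Embedding,
            norm: RMSNorm,
            output: nn.Linear,
        ):
            super().__init__()
            self.params = params
            self.layers = layers
            self.rope = rope
            self.tok_embeddings = tok_embeddings
            self.norm = norm
            self.output = output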
@@ -130,10 +159,8 @@ def __init__(self, params: ModelArgs):
             if self.apply_embedding
             else None
         )
-        self.rope = Rope(params)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        self.layers = layers
+        self.rope = rope
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
@@ -212,3 +239,23 @@ def forward(
             return logits, attn_options_update

         return logits
+
+
+def construct_transformer(model_args: ModelArgs) -> Transformer:
+    """
+    Construct a Transformer model from the given model arguments.
+    """
+    rope = Rope(model_args)
+    if model_args.attention_type not in ATTENTION_REGISTRY:
+        raise ValueError(
+            f"Unknown attention type: {model_args.attention_type}. "
+            f"Available: {list(ATTENTION_REGISTRY.keys())}"
+        )
+    layers = torch.nn.ModuleList()
+    cls = ATTENTION_REGISTRY[model_args.attention_type]
+    for layer_id in range(model_args.n_layers):
+        attention = cls(model_args, layer_id, rope)
+        transformer_block = TransformerBlock(model_args, attention)
+        layers.append(transformer_block)
+
+    return Transformer(model_args, layers, rope)

Review thread on `construct_transformer`:
- Why not …
- Discussed offline; construct_transformer is likely going to be more high-level: not quite at model creation, but it will contain e.g. LoRA instantiation, so it may not make sense for it to be part of the Transformer class itself.
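For callers, a short usage sketch of the new helper (illustrative only; written as if inside the modified module so `ModelArgs` and `construct_transformer` are in scope, and assuming `ModelArgs` can be constructed with defaults):

    # Build the full model via the new free function instead of wiring
    # Rope, the attention registry, and TransformerBlocks by hand.
    model_args = ModelArgs()  # real configs would set dim, n_layers, attention_type, ...
    model = construct_transformer(model_args)
    model.eval()  # an ordinary nn.Module from here on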
Review comment:
- Can you add a docstring for each argument, especially the attention? It makes sense to me that the Attention type is required, so that the API of user-defined attention is compatible with our transformer.
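To illustrate the point about user-defined attention being compatible with the transformer, here is a hedged sketch of a custom attention module injected through the new `TransformerBlock(args, attention)` constructor. The `Attention` base class lives in `attention.py` and its exact interface is not shown in this diff, so the subclassing details are assumptions; the call/return convention (positional `x`, `freqs_cos`, `freqs_sin` plus forward options, returning the output and an optional options update) is taken from how `TransformerBlock.forward` calls `self.attention` above.

    from torch import nn

    from executorch.examples.models.llama.attention import Attention

    class NaiveCustomAttention(Attention):  # assumption: Attention is an nn.Module subclass
        def __init__(self, args, layer_id, rope):
            super().__init__()
            self.layer_id = layer_id
            # Toy single projection standing in for real q/k/v/output projections.
            self.proj = nn.Linear(args.dim, args.dim, bias=False)

        def forward(self, x, freqs_cos, freqs_sin, **attn_options):
            # Return (output, options_update); None means no cache/state update.
            return self.proj(x), None

    # With the new constructor the instance is injected directly; no registry
    # entry is required unless from_type/construct_transformer must find it:
    # attention = NaiveCustomAttention(args, 0, rope)
    # block = TransformerBlock(args, attention)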