
Commit 01bff6b

Update on "Refactor attention v2"
Pull attention creation out of Transformer/TransformerBlock and instead pass the layers into Transformer. The motivation is to customize the linear layers in attention for LoRA (e.g., make wq a LoraLinear instead of a regular Linear). In the next diff (D73517350), we pull wq, wk, wv, wo out of the attention and pass those in as well. This lets us customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

I think this modularizes our attention/transformer components, though it also means users have to do some more work to construct the attention layers and pass them to Transformer. It follows the torchtune structure more closely, e.g. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)

[ghstack-poisoned]
2 parents: 6dd31a8 + 32a40b0
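
For context, here is a minimal caller-side sketch of the wiring this change introduces: the blocks are built up front (here via the legacy `from_type` path shown in the diff below) and handed to `Transformer`, instead of `Transformer` constructing attention internally. The import paths, the `Rope(params)` constructor, the `n_layers` field, and `ModelArgs()` having usable defaults are assumptions about the surrounding code, not something this diff guarantees.

```python
import torch.nn as nn

# Assumed import paths, based on the examples/models/llama layout.
from executorch.examples.models.llama.llama_transformer import (
    Transformer,
    TransformerBlock,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

params = ModelArgs()  # assumes default fields; real callers fill in the checkpoint config
rope = Rope(params)   # assumed constructor: rotary-embedding helper built from the args

# Build every block up front. from_type is the legacy path shown in the diff:
# it resolves params.attention_type through the ATTENTION_REGISTRY.
layers = nn.ModuleList(
    TransformerBlock.from_type(layer_id, params, rope)
    for layer_id in range(params.n_layers)  # n_layers assumed to exist on ModelArgs
)

# Transformer no longer builds its own attention/blocks; it just consumes them.
model = Transformer(params, layers, rope)
```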

File tree

1 file changed: +23 −0 lines


examples/models/llama/llama_transformer.py

Lines changed: 23 additions & 0 deletions
@@ -85,6 +85,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class TransformerBlock(nn.Module):
     def __init__(self, args: ModelArgs, attention: Attention):
+        """
+        Transformer block with support for pre-norm and post-norm.
+        Args:
+            args (ModelArgs): model configuration parameters.
+            attention (Attention): attention object to use in the transformer
+                block. See `attention.py` for types of attention. Make sure
+                the attention type is registered in the ATTENTION_REGISTRY.
+        """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -100,6 +108,13 @@ def __init__(self, args: ModelArgs, attention: Attention):
 
     @classmethod
     def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
+        """
+        Create a TransformerBlock with the legacy constructor.
+        Args:
+            layer_id (int): the index of the layer.
+            args (ModelArgs): model configuration parameters.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
         if args.attention_type not in ATTENTION_REGISTRY:
             raise ValueError(
                 f"Unknown attention type: {args.attention_type}. "
@@ -124,6 +139,14 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:
 
 class Transformer(nn.Module):
     def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
+        """
+        Transformer model.
+        Args:
+            params (ModelArgs): model configuration parameters.
+            layers (nn.ModuleList): list of transformer blocks - see the
+                `TransformerBlock` type above.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
