
Commit 6375fc2

Refactor attention v2
Pull attention creation out of Transformer/TransformerBlock and instead pass the constructed layers into Transformer. The motivation is to customize the linear layers inside attention for LoRA (e.g. turn wq into a LoraLinear instead of a regular nn.Linear). In the next diff (D73517350), we pull wq, wk, wv, and wo out of the attention and pass those in as well. This lets us customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py. It modularizes our attention/transformer components, though it also means users have to do a bit more work to construct the attention layers and pass them to Transformer. It follows the torchtune structure more closely, e.g. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)

[ghstack-poisoned]
1 parent 280db15 commit 6375fc2
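
The sketch below illustrates the construction pattern this change enables: attention modules are built outside the model and handed to TransformerBlock/Transformer, so a caller can swap individual projections before the block ever sees them. The LoraLinear class and the in-place wq swap are hypothetical illustrations only (the real wq/wk/wv/wo injection lands in D73517350), and the sketch assumes the attention implementation exposes a wq linear, as AttentionMHA does.

# Hypothetical sketch of the pattern enabled by this refactor; LoraLinear and
# the in-place wq swap are illustrative assumptions, not APIs added here.
import torch
import torch.nn as nn

from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
from executorch.examples.models.llama.llama_transformer import (
    Transformer,
    TransformerBlock,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope


class LoraLinear(nn.Module):
    """Toy LoRA wrapper: base linear plus a trainable low-rank update."""

    def __init__(self, base: nn.Linear, rank: int = 8):
        super().__init__()
        self.base = base
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_b(self.lora_a(x))


def build_lora_transformer(args: ModelArgs) -> Transformer:
    rope = Rope(args)
    attn_cls = ATTENTION_REGISTRY[args.attention_type]
    layers = nn.ModuleList()
    for layer_id in range(args.n_layers):
        attention = attn_cls(args, layer_id, rope)
        # Attention is constructed out here, so its projections can be
        # customized before being passed into the block.
        attention.wq = LoraLinear(attention.wq)
        layers.append(TransformerBlock(args, attention))
    return Transformer(args, layers, rope)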

File tree

4 files changed: +45 -20 lines

examples/models/llama/llama_transformer.py
examples/models/llama/model.py
examples/models/llama/tests/test_pre_quantization_transforms.py
examples/models/llama/tests/test_static_attention.py

examples/models/llama/llama_transformer.py

Lines changed: 37 additions & 13 deletions
@@ -13,6 +13,7 @@
 import torch.nn.functional as F

 from executorch.examples.models.llama.attention import (
+    Attention,
     ATTENTION_REGISTRY,
     ForwardOptions,
 )
@@ -83,25 +84,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+        self.attention = attention
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+    @classmethod
+    def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
+        if args.attention_type not in ATTENTION_REGISTRY:
+            raise ValueError(
+                f"Unknown attention type: {args.attention_type}. "
+                f"Available: {list(ATTENTION_REGISTRY.keys())}"
+            )
+        cls = ATTENTION_REGISTRY[args.attention_type]
+        attention = cls(args, layer_id, rope)
+        return TransformerBlock(args, attention)

     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
         h, attn_options_update = self.attention.forward(
@@ -117,7 +123,7 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x:


 class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
+    def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
@@ -130,10 +136,8 @@ def __init__(self, params: ModelArgs):
             if self.apply_embedding
             else None
         )
-        self.rope = Rope(params)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        self.layers = layers
+        self.rope = rope
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
@@ -212,3 +216,23 @@ def forward(
             return logits, attn_options_update

         return logits
+
+
+def construct_transformer(model_args: ModelArgs) -> Transformer:
+    """
+    Construct a Transformer model from the given model arguments.
+    """
+    rope = Rope(model_args)
+    if model_args.attention_type not in ATTENTION_REGISTRY:
+        raise ValueError(
+            f"Unknown attention type: {model_args.attention_type}. "
+            f"Available: {list(ATTENTION_REGISTRY.keys())}"
+        )
+    layers = torch.nn.ModuleList()
+    cls = ATTENTION_REGISTRY[model_args.attention_type]
+    for layer_id in range(model_args.n_layers):
+        attention = cls(model_args, layer_id, rope)
+        transformer_block = TransformerBlock(model_args, attention)
+        layers.append(transformer_block)
+
+    return Transformer(model_args, layers, rope)
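
As a usage note, here is a minimal sketch (with placeholder ModelArgs values) of the two construction paths this file now offers: the registry-driven helper shown above, and a manual build where the caller creates each block, e.g. via the new TransformerBlock.from_type classmethod.

# Minimal sketch; the ModelArgs values are placeholders chosen only to make
# the example self-contained.
import torch.nn as nn

from executorch.examples.models.llama.llama_transformer import (
    Transformer,
    TransformerBlock,
    construct_transformer,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

args = ModelArgs(dim=64, n_heads=4, n_layers=2, vocab_size=128)

# 1) Default path: registry lookup and layer construction handled internally.
model = construct_transformer(args)

# 2) Manual path: the caller owns attention/block creation, mirroring what
#    construct_transformer does and leaving room to customize each layer.
rope = Rope(args)
layers = nn.ModuleList(
    TransformerBlock.from_type(layer_id, args, rope)
    for layer_id in range(args.n_layers)
)
model_manual = Transformer(args, layers, rope)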

examples/models/llama/model.py

Lines changed: 3 additions & 2 deletions
@@ -15,9 +15,10 @@
     get_checkpoint_dtype,
     get_default_model_resource_dir,
 )
-from executorch.examples.models.llama.llama_transformer import Transformer

+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.rope import Rope
 from torchao.utils import TorchAOBaseTensor

 try:
@@ -174,7 +175,7 @@ def __init__(self, **kwargs):
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
-            self.model_ = Transformer(model_args)
+            self.model_ = construct_transformer(model_args)
         # Get checkpoint dtype.
         if checkpoint:
             self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)

examples/models/llama/tests/test_pre_quantization_transforms.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 import unittest

 import torch
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer, Transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.source_transformation.pre_quantization import (
     sanitize_checkpoint_from_pre_quantization,
@@ -39,7 +39,7 @@ def _prepare_dummy_model(self) -> Transformer:
             vocab_size=32000,
         )

-        model = Transformer(model_args)
+        model = construct_transformer(model_args)

         return model

examples/models/llama/tests/test_static_attention.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@

 import torch
 from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
@@ -160,10 +160,10 @@ def test_within_transformer(self):
             n_layers=4,
             vocab_size=128,
         )
-        mha_transformer = Transformer(config).eval()
+        mha_transformer = construct_transformer(config).eval()

         config.attention_type = "static"
-        static_transformer = Transformer(config).eval()
+        static_transformer = construct_transformer(config).eval()
         static_transformer.load_state_dict(mha_transformer.state_dict(), strict=False)
         for mha_layer, static_layer in zip(
             mha_transformer.layers, static_transformer.layers
