Commit a4d5fb9

Refactor attention v2
Differential Revision: D73538697 Pull Request resolved: #10623
1 parent 9b14984 commit a4d5fb9

File tree: 5 files changed (+73 −22 lines)

examples/models/llama/llama_transformer.py

Lines changed: 60 additions & 13 deletions
@@ -13,6 +13,7 @@
 import torch.nn.functional as F
 
 from executorch.examples.models.llama.attention import (
+    Attention,
     ATTENTION_REGISTRY,
     ForwardOptions,
 )
@@ -83,26 +84,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
+        """
+        Transformer block with support for pre-norm and post-norm.
+        Args:
+            args (ModelArgs): model configuration parameters.
+            attention (Attention): attention object to use in the transformer
+                block. See `attention.py` for types of attention. Make sure
+                the attention type is registered in the ATTENTION_REGISTRY.
+        """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+        self.attention = attention
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
+    @classmethod
+    def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
+        """
+        Create a TransformerBlock with the legacy constructor.
+        Args:
+            layer_id (int): the index of the layer.
+            args (ModelArgs): model configuration parameters.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
+        if args.attention_type not in ATTENTION_REGISTRY:
+            raise ValueError(
+                f"Unknown attention type: {args.attention_type}. "
+                f"Available: {list(ATTENTION_REGISTRY.keys())}"
+            )
+        cls = ATTENTION_REGISTRY[args.attention_type]
+        attention = cls(args, layer_id, rope)
+        return TransformerBlock(args, attention)
+
     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
         h, attn_options_update = self.attention.forward(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
@@ -117,7 +138,15 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x:
 
 
 class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
+    def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
+        """
+        Transformer model.
+        Args:
+            params (ModelArgs): model configuration parameters.
+            layers (nn.ModuleList): list of transformer blocks - see the
+                `TransformerBlock` type above.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
@@ -130,10 +159,8 @@ def __init__(self, params: ModelArgs):
             if self.apply_embedding
             else None
         )
-        self.rope = Rope(params)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        self.layers = layers
+        self.rope = rope
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
@@ -212,3 +239,23 @@ def forward(
             return logits, attn_options_update
 
         return logits
+
+
+def construct_transformer(model_args: ModelArgs) -> Transformer:
+    """
+    Construct a Transformer model from the given model arguments.
+    """
+    rope = Rope(model_args)
+    if model_args.attention_type not in ATTENTION_REGISTRY:
+        raise ValueError(
+            f"Unknown attention type: {model_args.attention_type}. "
+            f"Available: {list(ATTENTION_REGISTRY.keys())}"
+        )
+    layers = torch.nn.ModuleList()
+    cls = ATTENTION_REGISTRY[model_args.attention_type]
+    for layer_id in range(model_args.n_layers):
+        attention = cls(model_args, layer_id, rope)
+        transformer_block = TransformerBlock(model_args, attention)
+        layers.append(transformer_block)
+
+    return Transformer(model_args, layers, rope)
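
For orientation, a minimal standalone sketch of the new construction path: it mirrors what construct_transformer() above does internally, but with the attention module built and injected explicitly. Everything except the ModelArgs field values (which are illustrative assumptions) comes from the signatures in this diff.

# Sketch only: build attention modules from the registry and inject them
# into TransformerBlock, then hand the assembled layers to Transformer.
import torch

from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
from executorch.examples.models.llama.llama_transformer import (
    Transformer,
    TransformerBlock,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

model_args = ModelArgs(n_layers=4, vocab_size=128)  # assumed example config
rope = Rope(model_args)

attention_cls = ATTENTION_REGISTRY[model_args.attention_type]
layers = torch.nn.ModuleList(
    [
        TransformerBlock(model_args, attention_cls(model_args, layer_id, rope))
        for layer_id in range(model_args.n_layers)
    ]
)
transformer = Transformer(model_args, layers, rope)

# Per-layer equivalent using the legacy-style classmethod added above:
# block = TransformerBlock.from_type(layer_id, model_args, rope)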

examples/models/llama/model.py

Lines changed: 3 additions & 2 deletions
@@ -15,9 +15,10 @@
     get_checkpoint_dtype,
     get_default_model_resource_dir,
 )
-from executorch.examples.models.llama.llama_transformer import Transformer
 
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.rope import Rope
 from torchao.utils import TorchAOBaseTensor
 
 try:
@@ -174,7 +175,7 @@ def __init__(self, **kwargs):
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
-            self.model_ = Transformer(model_args)
+            self.model_ = construct_transformer(model_args)
         # Get checkpoint dtype.
         if checkpoint:
             self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
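
The call-site swap above keeps the existing meta-device initialization flow. As a rough standalone sketch (the ModelArgs values are assumptions, not taken from this diff):

# Parameters are allocated on the "meta" device (shapes and dtypes only, no
# storage); real weights come from a checkpoint afterwards.
import torch

from executorch.examples.models.llama.llama_transformer import construct_transformer
from executorch.examples.models.llama.model_args import ModelArgs

model_args = ModelArgs(n_layers=4, vocab_size=128)  # assumed example config
with torch.device("meta"):
    model = construct_transformer(model_args)  # fp32 meta tensors, no real allocation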

examples/models/llama/tests/test_pre_quantization_transforms.py

Lines changed: 5 additions & 2 deletions
@@ -7,7 +7,10 @@
 import unittest
 
 import torch
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import (
+    construct_transformer,
+    Transformer,
+)
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.source_transformation.pre_quantization import (
     sanitize_checkpoint_from_pre_quantization,
@@ -39,7 +42,7 @@ def _prepare_dummy_model(self) -> Transformer:
             vocab_size=32000,
         )
 
-        model = Transformer(model_args)
+        model = construct_transformer(model_args)
 
         return model
 

examples/models/llama/tests/test_static_attention.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@
 
 import torch
 from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
@@ -160,10 +160,10 @@ def test_within_transformer(self):
             n_layers=4,
             vocab_size=128,
         )
-        mha_transformer = Transformer(config).eval()
+        mha_transformer = construct_transformer(config).eval()
 
         config.attention_type = "static"
-        static_transformer = Transformer(config).eval()
+        static_transformer = construct_transformer(config).eval()
         static_transformer.load_state_dict(mha_transformer.state_dict(), strict=False)
         for mha_layer, static_layer in zip(
             mha_transformer.layers, static_transformer.layers
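
The pattern exercised by test_within_transformer above can be summarized as a short sketch (the visible field values come from the test; other ModelArgs fields not shown in the hunk are left at their defaults here):

# Build an MHA transformer and a static-attention transformer from the same
# config, then carry the MHA weights over to the static model.
from executorch.examples.models.llama.llama_transformer import construct_transformer
from executorch.examples.models.llama.model_args import ModelArgs

config = ModelArgs(n_layers=4, vocab_size=128)
mha_transformer = construct_transformer(config).eval()

config.attention_type = "static"
static_transformer = construct_transformer(config).eval()
# Parameter names differ between the two attention types, hence strict=False.
static_transformer.load_state_dict(mha_transformer.state_dict(), strict=False)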

examples/models/llava/model.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 
 import requests
 import torch
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 
 from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
@@ -66,7 +66,7 @@ def __init__(
             use_hf_rope=True,
             max_seq_len=max_seq_len,
         )
-        self.text_model = Transformer(self.text_model_args)
+        self.text_model = construct_transformer(self.text_model_args)
         # use custom op for SDPA.
         if use_sdpa_with_kv_cache_op:
             self.text_model = replace_kv_cache_with_custom_kv_cache(self.text_model)
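
For the LLaVA text model, construction now goes through the same factory. A rough sketch, assuming the max_seq_len value and assuming replace_kv_cache_with_custom_kv_cache is exported by the custom_kv_cache module imported above:

from executorch.examples.models.llama.llama_transformer import construct_transformer
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    replace_kv_cache_with_custom_kv_cache,
)

# Text-model config: use_hf_rope/max_seq_len mirror the diff; 768 is an assumed value.
text_model_args = ModelArgs(use_hf_rope=True, max_seq_len=768)
text_model = construct_transformer(text_model_args)

# Optionally swap in the custom KV cache / SDPA op, as the constructor above does.
use_sdpa_with_kv_cache_op = True
if use_sdpa_with_kv_cache_op:
    text_model = replace_kv_cache_with_custom_kv_cache(text_model)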
