import torch.nn.functional as F

from executorch.examples.models.llama.attention import (
+    Attention,
    ATTENTION_REGISTRY,
    ForwardOptions,
)
@@ -83,26 +84,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
+        """
+        Transformer block with support for pre-norm and post-norm.
+        Args:
+            args (ModelArgs): model configuration parameters.
+            attention (Attention): attention object to use in the transformer
+                block. See `attention.py` for types of attention. Make sure
+                the attention type is registered in the ATTENTION_REGISTRY.
+        """
        super().__init__()
        self.use_kv_cache = args.use_kv_cache
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.head_dim
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+        self.attention = attention
        if args.moe:
            self.block_sparse_moe = MOEFeedForward(args)
        else:
            self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

+    @classmethod
+    def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
+        """
+        Create a TransformerBlock with the legacy constructor.
+        Args:
+            layer_id (int): the index of the layer.
+            args (ModelArgs): model configuration parameters.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
+        if args.attention_type not in ATTENTION_REGISTRY:
+            raise ValueError(
+                f"Unknown attention type: {args.attention_type}. "
+                f"Available: {list(ATTENTION_REGISTRY.keys())}"
+            )
+        cls = ATTENTION_REGISTRY[args.attention_type]
+        attention = cls(args, layer_id, rope)
+        return TransformerBlock(args, attention)
+
    def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
        h, attn_options_update = self.attention.forward(
            self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
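A hedged usage sketch (not part of the diff) of the two construction paths above, assuming the names already imported in this module (ModelArgs, Rope, ATTENTION_REGISTRY) and purely illustrative ModelArgs values:

args = ModelArgs(dim=256, n_heads=4, n_layers=2, vocab_size=512)  # illustrative sizes; real defaults may differ
rope = Rope(args)

# New path: resolve the attention class and inject a constructed instance.
attn_cls = ATTENTION_REGISTRY[args.attention_type]
block = TransformerBlock(args, attn_cls(args, 0, rope))  # positional args: (args, layer_id, rope)

# Legacy path: from_type performs the same registry lookup internally.
legacy_block = TransformerBlock.from_type(0, args, rope)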
@@ -117,7 +138,15 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x:


class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
+    def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
+        """
+        Transformer model.
+        Args:
+            params (ModelArgs): model configuration parameters.
+            layers (nn.ModuleList): list of transformer blocks - see the
+                `TransformerBlock` type above.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
@@ -130,10 +159,8 @@ def __init__(self, params: ModelArgs):
            if self.apply_embedding
            else None
        )
-        self.rope = Rope(params)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        self.layers = layers
+        self.rope = rope
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = (
            nn.Linear(params.dim, params.vocab_size, bias=False)
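Injecting layers and rope is what lets callers assemble the layer stack themselves. A sketch under the same assumptions as above (reusing args and rope), for example a unit test building a one-block model instead of params.n_layers blocks:

single = TransformerBlock(args, ATTENTION_REGISTRY[args.attention_type](args, 0, rope))
tiny_model = Transformer(args, torch.nn.ModuleList([single]), rope)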
@@ -212,3 +239,23 @@ def forward(
            return logits, attn_options_update

        return logits
+
+
+def construct_transformer(model_args: ModelArgs) -> Transformer:
+    """
+    Construct a Transformer model from the given model arguments.
+    """
+    rope = Rope(model_args)
+    if model_args.attention_type not in ATTENTION_REGISTRY:
+        raise ValueError(
+            f"Unknown attention type: {model_args.attention_type}. "
+            f"Available: {list(ATTENTION_REGISTRY.keys())}"
+        )
+    layers = torch.nn.ModuleList()
+    cls = ATTENTION_REGISTRY[model_args.attention_type]
+    for layer_id in range(model_args.n_layers):
+        attention = cls(model_args, layer_id, rope)
+        transformer_block = TransformerBlock(model_args, attention)
+        layers.append(transformer_block)
+
+    return Transformer(model_args, layers, rope)
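The factory above becomes the standard entry point in place of the old Transformer(params) constructor. A minimal sketch, again reusing the illustrative args from the earlier snippet:

model = construct_transformer(args)
assert len(model.layers) == args.n_layers  # one TransformerBlock per layer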