from __future__ import annotations
from typing_extensions import override
import torch
from .config import Config, no_default
from .model import Model
from ..util.rope import RopeSettings, RopeStyle
from ..modules import RMSNorm, Embedding, TransformerBlock, Attention, GatedMLP, Linear
from ..modules.attn import prepare_for_attn
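
# DeciLM ("DeciLMForCausalLM") lets every decoder layer have its own attention and FFN geometry,
# described by one block_configs entry per layer; this module maps those entries onto the shared
# Embedding / TransformerBlock / Linear module stack.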

class DeciLMConfig(Config):

    def __init__(
        self,
        directory: str,
        **kwargs,
    ):
        super().__init__(
            directory,
            kwargs.get("arch_string", "DeciLMForCausalLM"),
            DeciLMModel,
            **kwargs
        )

        # Global attention params
        self.head_dim = self.read_cfg(int, "head_dim", None)
        self.hidden_size = self.read_cfg(int, "hidden_size", no_default)
        self.num_q_heads = self.read_cfg(int, "num_attention_heads", no_default)

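        # head_dim is optional in DeciLM checkpoints; fall back to the conventional value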
        if not self.head_dim:
            self.head_dim = self.hidden_size // self.num_q_heads

        # MLP params
        self.assert_cfg(str, "hidden_act", "silu", True)

        # Norms
        self.rms_norm_eps = self.read_cfg(float, "rms_norm_eps", no_default)

        # Layers
        self.num_hidden_layers = self.read_cfg(int, "num_hidden_layers", no_default)
        self.tie_word_embeddings = self.read_cfg(bool, "tie_word_embeddings", False)

        # RoPE
        self.rope_settings = self.read_rope_settings_default(RopeStyle.NEOX)

        # Block configs
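        # Each entry pairs an "attention" dict with an "ffn" dict; illustrative shape only,
        # the field values below are made up and not taken from any specific checkpoint:
        #   {"attention": {"no_op": False, "n_heads_in_group": 8},
        #    "ffn": {"no_op": False, "ffn_mult": 2.625}}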
        self.block_configs = self.read_cfg(list, "block_configs", no_default)
        assert len(self.block_configs) == self.num_hidden_layers, \
            "Number of hidden layers does not match length of block_configs list"


class DeciLMModel(Model):

    def __init__(
        self,
        config: DeciLMConfig,
        **kwargs
    ):
        super().__init__(config, **kwargs)

        self.modules += [
            Embedding(
                config = config,
                key = "model.embed_tokens",
                vocab_size = config.vocab_size,
                hidden_size = config.hidden_size,
            )
        ]

        self.first_block_idx = len(self.modules)
        self.last_kv_module_idx = 0
        cache_layer_idx = 0

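        # One TransformerBlock per block_configs entry. Attention-less (no_op) layers contribute
        # no KV cache layer, so cache_layer_idx advances only for blocks that actually attend.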
        for idx, cfg in enumerate(config.block_configs):
            cfg_attn = cfg["attention"]
            cfg_ffn = cfg["ffn"]

            if cfg_attn.get("no_op"):
                attn_norm = None
                attn = None
            else:
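                # Attention variants from the DeciLM config schema that are not implemented here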
                assert not cfg_attn.get("num_sink_tokens"), "DeciLM: num_sink_tokens not supported"
                assert not cfg_attn.get("replace_with_linear"), "DeciLM: replace_with_linear not supported"
                assert not cfg_attn.get("sparsify"), "DeciLM: sparsify not supported"
                assert not cfg_attn.get("unshifted_sink"), "DeciLM: unshifted_sink not supported"
                assert not cfg_attn.get("use_prefill_window_in_sink_attention"), \
                    "DeciLM: use_prefill_window_in_sink_attention not supported"
                attn_norm = RMSNorm(
                    config = config,
                    key = f"model.layers.{idx}.input_layernorm",
                    rms_norm_eps = config.rms_norm_eps,
                )
                attn = Attention(
                    config = config,
                    key = f"model.layers.{idx}.self_attn",
                    layer_idx = cache_layer_idx,
                    hidden_size = config.hidden_size,
                    head_dim = config.head_dim,
                    num_q_heads = config.num_q_heads,
                    # n_heads_in_group is the GQA group size (query heads per KV head)
                    num_kv_heads = config.num_q_heads // cfg_attn["n_heads_in_group"],
                    rope_settings = config.rope_settings,
                    sm_scale = None,
                    key_q = "q_proj",
                    key_k = "k_proj",
                    key_v = "v_proj",
                    key_o = "o_proj",
                    qmap = "block.attn",
                )
                cache_layer_idx += 1
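                # Index this block will occupy once appended below, i.e. the last module
                # in the stack that touches the KV cache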
                self.last_kv_module_idx = len(self.modules)

            if cfg_ffn.get("no_op"):
                mlp_norm = None
                mlp = None
            else:
                assert not cfg_ffn.get("replace_with_linear"), "DeciLM: replace_with_linear not supported"
                assert not cfg_ffn.get("sparsify"), "DeciLM: sparsify not supported"
                mlp_norm = RMSNorm(
                    config = config,
                    key = f"model.layers.{idx}.post_attention_layernorm",
                    rms_norm_eps = config.rms_norm_eps,
                )
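                # DeciLM stores the MLP width as a multiplier: roughly 2/3 * ffn_mult * hidden_size,
                # rounded up to a multiple of 256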
                interm_size = int(2 * cfg_ffn["ffn_mult"] * config.hidden_size / 3)
                interm_size = ((interm_size + 255) // 256) * 256
                mlp = GatedMLP(
                    config = config,
                    key = f"model.layers.{idx}.mlp",
                    hidden_size = config.hidden_size,
                    intermediate_size = interm_size,
                    key_up = "up_proj",
                    key_gate = "gate_proj",
                    key_down = "down_proj",
                    qmap = "block.mlp",
                )

            self.modules += [
                TransformerBlock(
                    config = config,
                    key = f"model.layers.{idx}",
                    attn_norm = attn_norm,
                    attn = attn,
                    mlp_norm = mlp_norm,
                    mlp = mlp,
                )
            ]

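        # With tied word embeddings and no separate lm_head tensor in the checkpoint, load the
        # output head from the embedding weights instead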
        head_alt_key = None
        if config.tie_word_embeddings and not self.config.stc.has_tensor("lm_head"):
            head_alt_key = "model.embed_tokens"

        self.modules += [
            RMSNorm(
                config = config,
                key = "model.norm",
                rms_norm_eps = config.rms_norm_eps,
                out_dtype = torch.half,
            ),
            Linear(
                config = config,
                key = "lm_head",
                qbits_key = "head_bits",
                alt_key = head_alt_key,
                in_features = config.hidden_size,
                out_features = config.vocab_size,
                qmap = "block",
                caps = {"logits_output": True}
            )
        ]

        self.logit_layer_idx = len(self.modules) - 1

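    # Stash the raw input_ids in params and let the shared prepare_for_attn helper attach the
    # attention metadata the rest of the stack expects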
    @override
    def prepare_inputs(self, input_ids: torch.Tensor, params: dict) -> torch.Tensor:
        params["input_ids"] = input_ids
        input_ids = prepare_for_attn(input_ids, params)
        return input_ids