@@ -130,14 +130,14 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
         self.modules.append(ExLlamaV2Embedding(self, "model.embed_tokens"))
         self.modules_dict[self.modules[-1].key] = self.modules[-1]

-        for layer_idx in range(self.config.num_hidden_layers):
+        for layer_list in range(self.config.num_hidden_layers):

-            self.modules.append(ExLlamaV2Attention(self, f"model.layers.{layer_idx}", layer_idx))
+            self.modules.append(ExLlamaV2Attention(self, f"model.layers.{layer_list}", layer_list))
             for m in self.modules[-1].submodules: self.modules_dict[m.key] = m
             if self.config.architecture == "Mixtral":
-                self.modules.append(ExLlamaV2MoEMLP(self, f"model.layers.{layer_idx}", layer_idx))
+                self.modules.append(ExLlamaV2MoEMLP(self, f"model.layers.{layer_list}", layer_list))
             else:
-                self.modules.append(ExLlamaV2MLP(self, f"model.layers.{layer_idx}", layer_idx))
+                self.modules.append(ExLlamaV2MLP(self, f"model.layers.{layer_list}", layer_list))
             for m in self.modules[-1].submodules: self.modules_dict[m.key] = m

@@ -150,15 +150,40 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):

         # Find last layer that affects k/v cache

-        layer_idx = len(self.modules)
+        layer_list = len(self.modules)
         while True:
-            layer_idx -= 1
-            if isinstance(self.modules[layer_idx], ExLlamaV2Attention):
+            layer_list -= 1
+            if isinstance(self.modules[layer_list], ExLlamaV2Attention):
                 break

-        self.last_kv_layer_idx = layer_idx
+        self.last_kv_layer_idx = layer_list


+        if hasattr(config, 'repeats'):
+            self.layers = []
+
+            def listLeftIndex(alist, value):
+                if value == 0:
+                    return 0
+                return alist.index(str(value))
+
+            def listRightIndex(alist, value):
+                if value > len(alist):
+                    return -1
+                return len(alist) - alist[-1::-1].index(str(value)) - 1
+
+            layer_list = [layer.key.split(".")[-1] for layer in self.modules]
+
+            for interval in config.repeats:
+                start_idx = listLeftIndex(layer_list, interval[0])
+                end_idx = listRightIndex(layer_list, interval[1])
+                self.layers.extend(list(range(start_idx, end_idx + 1)))
+            self.layers.extend(list(range(listRightIndex(layer_list, config.repeats[-1][1]), len(layer_list))))
+
+            # If we have created a Frankenmerge, let's print it to verify!
+            for layer in self.layers:
+                print(layer, self.modules[layer].key)
+
     def set_device_map(self, allocation, embed_cpu = True):

         self.cache_map = {}
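
Note: below is a standalone sketch of what the `config.repeats` block above computes, run against a hypothetical 4-layer model. The toy key list, the helper names and the assumption that `config.repeats` holds [start, end] layer-number pairs are illustrative only; the index arithmetic mirrors the added code.

# Standalone sketch (assumptions: 4 layers, repeats given as [start, end] layer numbers).
# Each transformer layer contributes two modules (attention + MLP) sharing one
# "model.layers.N" key, so matching on the key suffix picks up both of them.

module_keys = (
    ["model.embed_tokens"]
    + [f"model.layers.{n}" for n in range(4) for _ in range(2)]   # attn + mlp per layer
    + ["model.norm", "lm_head"]
)
suffixes = [key.split(".")[-1] for key in module_keys]

def left_index(alist, value):
    # first module of layer `value`; 0 also covers the embedding module
    return 0 if value == 0 else alist.index(str(value))

def right_index(alist, value):
    # last module of layer `value`
    return len(alist) - alist[::-1].index(str(value)) - 1

repeats = [[0, 2], [1, 3]]   # hypothetical frankenmerge: layers 0-2, then 1-3 again
schedule = []
for start, end in repeats:
    schedule.extend(range(left_index(suffixes, start), right_index(suffixes, end) + 1))
# everything from the last repeated module to the end of the model (norm, lm_head)
schedule.extend(range(right_index(suffixes, repeats[-1][1]), len(suffixes)))

for i in schedule:
    print(i, module_keys[i])

The resulting schedule is what `self.layers` holds: indices into `self.modules`, visited in order (and possibly more than once) by `_forward` in the hunks below.
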
@@ -582,6 +607,23 @@ def _forward(self,
                  return_last_state = False,
                  position_offsets = None):

+        def process_module(module, x, last_state):
+            device = _torch_device(module.device_idx)
+
+            if idx == self.head_layer_idx:
+                if last_id_only and return_last_state:
+                    x = x.narrow(-2, -1, 1)
+                    last_state = x
+                elif last_id_only:
+                    x = x.narrow(-2, -1, 1)
+                elif return_last_state:
+                    last_state = x.narrow(-2, -1, 1)
+
+            x = safe_move_tensor(x, device)
+            x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras)
+
+            return x, last_state
+
         batch_size, seq_len = input_ids.shape
         past_len = 0
         if cache is not None:
@@ -596,27 +638,19 @@ def _forward(self,
         attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
         last_state = None

-        for idx, module in enumerate(self.modules):
-
-            device = _torch_device(module.device_idx)
-
-            # Onward
-
-            if idx == self.head_layer_idx:
-                if last_id_only and return_last_state:
-                    x = x.narrow(-2, -1, 1)
-                    last_state = x
-                elif last_id_only:
-                    x = x.narrow(-2, -1, 1)
-                elif return_last_state:
-                    last_state = x.narrow(-2, -1, 1)
-
-            x = safe_move_tensor(x, device)
-            x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras)
-
-            if preprocess_only and idx == self.last_kv_layer_idx:
-                x = None
-                break
+        if hasattr(self, 'layers'):
+            for i, idx in enumerate(self.layers):
+                module = self.modules[idx]
+                x, last_state = process_module(module, x, last_state)
+                if preprocess_only and idx == self.last_kv_layer_idx:
+                    x = None
+                    break
+        else:
+            for idx, module in enumerate(self.modules):
+                x, last_state = process_module(module, x, last_state)
+                if preprocess_only and idx == self.last_kv_layer_idx:
+                    x = None
+                    break

         # Advance cache

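
For orientation, a rough control-flow sketch of what the refactor above does in `_forward`: when a `self.layers` schedule exists (the repeats case), module indices are taken from it, so a module can run more than once; otherwise every module runs once, in order. In the diff itself, `process_module` also reads `idx`, `last_id_only`, `return_last_state`, `cache`, `attn_params`, `past_len` and `loras` from the enclosing scope rather than taking them as arguments. The stub class and function below are made-up placeholders, not exllamav2 API.

# Control-flow sketch only; StubModule and run() are hypothetical stand-ins.
class StubModule:
    def __init__(self, key):
        self.key = key
    def forward(self, x):
        return x + [self.key]          # record visit order instead of doing real math

def run(modules, schedule = None):
    x = []
    def process_module(module, x):
        # the real helper also slices the last token and moves x between devices
        return module.forward(x)

    if schedule is not None:           # frankenmerge path: follow the schedule
        for i, idx in enumerate(schedule):
            x = process_module(modules[idx], x)
    else:                              # default path: every module once, in order
        for idx, module in enumerate(modules):
            x = process_module(module, x)
    return x

modules = [StubModule(f"model.layers.{n}") for n in range(4)]
print(run(modules))                                   # layers 0, 1, 2, 3
print(run(modules, schedule = [0, 1, 2, 1, 2, 3]))    # layers 1-2 visited twice
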