Commit baa1efe

Allow layer repeats

1 parent 24c686b commit baa1efe
2 files changed: +77 -31 lines
exllamav2/model.py

Lines changed: 63 additions & 29 deletions
@@ -130,14 +130,14 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
         self.modules.append(ExLlamaV2Embedding(self, "model.embed_tokens"))
         self.modules_dict[self.modules[-1].key] = self.modules[-1]

-        for layer_idx in range(self.config.num_hidden_layers):
+        for layer_list in range(self.config.num_hidden_layers):

-            self.modules.append(ExLlamaV2Attention(self, f"model.layers.{layer_idx}", layer_idx))
+            self.modules.append(ExLlamaV2Attention(self, f"model.layers.{layer_list}", layer_list))
             for m in self.modules[-1].submodules: self.modules_dict[m.key] = m
             if self.config.architecture == "Mixtral":
-                self.modules.append(ExLlamaV2MoEMLP(self, f"model.layers.{layer_idx}", layer_idx))
+                self.modules.append(ExLlamaV2MoEMLP(self, f"model.layers.{layer_list}", layer_list))
             else:
-                self.modules.append(ExLlamaV2MLP(self, f"model.layers.{layer_idx}", layer_idx))
+                self.modules.append(ExLlamaV2MLP(self, f"model.layers.{layer_list}", layer_list))
             for m in self.modules[-1].submodules: self.modules_dict[m.key] = m

@@ -150,15 +150,40 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):

         # Find last layer that affects k/v cache

-        layer_idx = len(self.modules)
+        layer_list = len(self.modules)
         while True:
-            layer_idx -= 1
-            if isinstance(self.modules[layer_idx], ExLlamaV2Attention):
+            layer_list -= 1
+            if isinstance(self.modules[layer_list], ExLlamaV2Attention):
                 break

-        self.last_kv_layer_idx = layer_idx
+        self.last_kv_layer_idx = layer_list


+        if hasattr(config, 'repeats'):
+            self.layers = []
+
+            def listLeftIndex(alist, value):
+                if value == 0:
+                    return 0
+                return alist.index(str(value))
+
+            def listRightIndex(alist, value):
+                if value > len(alist):
+                    return -1
+                return len(alist) - alist[-1::-1].index(str(value)) - 1
+
+            layer_list = [layer.key.split(".")[-1] for layer in self.modules]
+
+            for interval in config.repeats:
+                start_idx = listLeftIndex(layer_list, interval[0])
+                end_idx = listRightIndex(layer_list, interval[1])
+                self.layers.extend(list(range(start_idx, end_idx + 1)))
+            self.layers.extend(list(range(listRightIndex(layer_list, config.repeats[-1][1]), len(layer_list))))
+
+            # If we have created a Frankenmerge, let's print it to verify!
+            for layer in self.layers:
+                print(layer, self.modules[layer].key)
+
     def set_device_map(self, allocation, embed_cpu = True):

         self.cache_map = {}
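The interval expansion above is easiest to see on a toy module list. Below is a minimal sketch, not part of the commit, assuming 12 decoder layers with two modules (attention, MLP) each and a hypothetical repeats spec of [(0, 7), (4, 11)]:

# Toy reconstruction of the expansion above (illustrative only):
# module keys reduced to their suffixes, as layer_list does in the commit.
keys = ["embed_tokens"] + [str(i) for i in range(12) for _ in (0, 1)] + ["norm", "lm_head"]

def list_left_index(alist, value):
    return 0 if value == 0 else alist.index(str(value))

def list_right_index(alist, value):
    return len(alist) - alist[::-1].index(str(value)) - 1

repeats = [(0, 7), (4, 11)]          # hypothetical --repeats value
layers = []
for start, end in repeats:
    layers.extend(range(list_left_index(keys, start), list_right_index(keys, end) + 1))
layers.extend(range(list_right_index(keys, repeats[-1][1]), len(keys)))

# Visits layers 0-7, then 4-11 again, then norm and lm_head. Note the tail
# range starts at the last interval's right index rather than one past it,
# so that final module appears twice; the print loop the commit adds makes
# such overlaps visible.
print([keys[i] for i in layers])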
@@ -582,6 +607,23 @@ def _forward(self,
                  return_last_state = False,
                  position_offsets = None):

+        def process_module(module, x, last_state):
+            device = _torch_device(module.device_idx)
+
+            if idx == self.head_layer_idx:
+                if last_id_only and return_last_state:
+                    x = x.narrow(-2, -1, 1)
+                    last_state = x
+                elif last_id_only:
+                    x = x.narrow(-2, -1, 1)
+                elif return_last_state:
+                    last_state = x.narrow(-2, -1, 1)
+
+            x = safe_move_tensor(x, device)
+            x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras)
+
+            return x, last_state
+
         batch_size, seq_len = input_ids.shape
         past_len = 0
         if cache is not None:
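One subtlety worth flagging: process_module takes module, x, and last_state as parameters, but reads idx, last_id_only, return_last_state, cache, attn_params, past_len, and loras from the enclosing _forward scope, so it sees the dispatch loop's current idx at call time. A minimal sketch of that late-binding behavior:

# Minimal sketch: an inner function resolves an enclosing-scope name
# when it is called, not when it is defined.
def outer():
    def report():
        return idx              # looked up in outer()'s scope at call time
    for idx in range(3):
        print(report())         # prints 0, then 1, then 2

outer()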
@@ -596,27 +638,19 @@ def _forward(self,
         attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
         last_state = None

-        for idx, module in enumerate(self.modules):
-
-            device = _torch_device(module.device_idx)
-
-            # Onward
-
-            if idx == self.head_layer_idx:
-                if last_id_only and return_last_state:
-                    x = x.narrow(-2, -1, 1)
-                    last_state = x
-                elif last_id_only:
-                    x = x.narrow(-2, -1, 1)
-                elif return_last_state:
-                    last_state = x.narrow(-2, -1, 1)
-
-            x = safe_move_tensor(x, device)
-            x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras)
-
-            if preprocess_only and idx == self.last_kv_layer_idx:
-                x = None
-                break
+        if hasattr(self, 'layers'):
+            for i, idx in enumerate(self.layers):
+                module = self.modules[idx]
+                x, last_state = process_module(module, x, last_state)
+                if preprocess_only and idx == self.last_kv_layer_idx:
+                    x = None
+                    break
+        else:
+            for idx, module in enumerate(self.modules):
+                x, last_state = process_module(module, x, last_state)
+                if preprocess_only and idx == self.last_kv_layer_idx:
+                    x = None
+                    break

         # Advance cache

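The forward change itself reduces to replaying modules through an index sequence when self.layers exists. A stripped-down sketch of that control flow, with placeholder callables standing in for the real modules:

# Stripped-down sketch of the repeat-aware dispatch (placeholder modules).
modules = [lambda x, i=i: x + i for i in range(4)]   # stand-ins for real layers
layers = [0, 1, 2, 1, 2, 3]                          # index sequence with a repeat

x = 0
for idx in layers:
    x = modules[idx](x)
print(x)   # 0 + 0 + 1 + 2 + 1 + 2 + 3 = 9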
exllamav2/model_init.py

Lines changed: 14 additions & 2 deletions
@@ -1,5 +1,5 @@

-import argparse, sys, os, glob
+import argparse, sys, os, glob, ast

 from exllamav2 import(
     ExLlamaV2,
@@ -17,6 +17,7 @@ def add_args(parser):
     parser.add_argument("-nfa", "--no_flash_attn", action = "store_true", help = "Disable Flash Attention")
     parser.add_argument("-lm", "--low_mem", action = "store_true", help = "Enable VRAM optimizations, potentially trading off speed")
     parser.add_argument("-ept", "--experts_per_token", type = int, help = "Override MoE model's default number of experts per token")
+    parser.add_argument("--repeats", type = parse_tuple_list, help = "List of (start, end) tuples of layers to repeat")


 def print_options(args):
@@ -60,6 +61,16 @@ def check_args(args):
             print(f" ## Error: Cannot find {filename} in {args.model_dir}")
             sys.exit()

+def parse_tuple_list(string):
+    try:
+        # Safely evaluate the string as a Python literal (list of tuples)
+        tuple_list = ast.literal_eval(string)
+        if not all(isinstance(item, tuple) for item in tuple_list):
+            raise ValueError
+        return tuple_list
+    except (ValueError, SyntaxError, TypeError):
+        raise argparse.ArgumentTypeError("Input must be a list of tuples")
+

 def init(args, quiet = False, allow_auto_split = False, skip_load = False):

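For reference, parse_tuple_list accepts any string that ast.literal_eval can turn into a list of tuples; anything else is rejected at argument-parsing time. The sample inputs below are hypothetical:

# Hypothetical inputs, shown to illustrate the accepted format.
parse_tuple_list("[(0, 20), (10, 30)]")   # -> [(0, 20), (10, 30)]
parse_tuple_list("[(0, 20)]")             # -> [(0, 20)]
parse_tuple_list("0-20, 10-30")           # raises argparse.ArgumentTypeError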
@@ -76,7 +87,8 @@ def init(args, quiet = False, allow_auto_split = False, skip_load = False):
     if args.rope_alpha: config.scale_alpha_value = args.rope_alpha
     config.no_flash_attn = args.no_flash_attn
     if args.experts_per_token: config.num_experts_per_token = args.experts_per_token
-
+    if args.repeats: config.repeats = args.repeats
+
     # Set low-mem options

     if args.low_mem: config.set_low_mem()

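Putting the pieces together, the new flag flows from the command line through add_args and parse_tuple_list into config.repeats. A minimal sketch, assuming model_init is importable from the exllamav2 package and its existing -m/--model_dir argument; the model path and repeats value are placeholders:

# Minimal sketch of the argument flow (paths and values are placeholders).
import argparse
from exllamav2 import model_init

parser = argparse.ArgumentParser()
model_init.add_args(parser)
args = parser.parse_args(["-m", "/path/to/model", "--repeats", "[(0, 20), (10, 30)]"])
print(args.repeats)   # [(0, 20), (10, 30)], later copied to config.repeats by init()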