Skip to content

Commit a4e75d4

Browse files
committed
Updated to match the format of Mergekit
1 parent 3afecf7 commit a4e75d4

File tree

2 files changed

+26
-28
lines changed

2 files changed

+26
-28
lines changed

exllamav2/model.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -158,32 +158,24 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
158158

159159
self.last_kv_layer_idx = layer_list
160160

161-
162161
if hasattr(config, 'repeats'):
163-
self.layers_list = []
164-
165-
def listLeftIndex(alist, value):
166-
if value == 0:
167-
return 0
168-
return alist.index(str(value))
169-
170-
def listRightIndex(alist, value):
171-
if value > len(alist):
172-
return -1
173-
return len(alist) - alist[-1::-1].index(str(value)) -1
174-
175-
layer_list = [layer.key.split(".")[-1] for layer in self.modules]
176-
177-
for interval in config.repeats:
178-
start_idx = listLeftIndex(layer_list, interval[0])
179-
end_idx = listRightIndex(layer_list, interval[1])
180-
self.layers_list.extend(list(range(start_idx, end_idx + 1)))
181-
self.layers_list.extend(list(range(listRightIndex(layer_list, config.repeats[-1][1]), len(layer_list))))
182-
183-
# If we have created a Frankenmerge, let's print it to verify!
184-
print("Frankenstein Layers list:")
185-
for i, layer in enumerate(self.layers_list):
186-
print(i, self.modules[layer].key)
162+
embedTokenLayers = 1
163+
transformerSublayers = 2
164+
layer_arrangement = [list(range(*interval)) for interval in config.repeats]
165+
layer_arrangement = [item for sublist in layer_arrangement for item in sublist]
166+
LayeredModules = self.modules
167+
168+
169+
self.modules = LayeredModules[:embedTokenLayers]
170+
for idx in layer_arrangement:
171+
self.modules += LayeredModules[idx*transformerSublayers + embedTokenLayers : idx*transformerSublayers + transformerSublayers + embedTokenLayers]
172+
self.modules += LayeredModules[-2:]
173+
self.head_layer_idx = len(self.modules) -1
174+
self.last_kv_layer_idx = len(self.modules) -4
175+
176+
for i, m in enumerate(self.modules):
177+
print(i, m.key)
178+
187179

188180
def set_device_map(self, allocation, embed_cpu = True):
189181

exllamav2/model_init.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,17 @@ def parse_tuple_list(string):
6565
try:
6666
# Safely evaluate the string as a Python literal (list of tuples)
6767
tuple_list = ast.literal_eval(string)
68+
69+
# Ensure all elements in the list are tuples
6870
if not all(isinstance(item, tuple) for item in tuple_list):
69-
raise ValueError
70-
return tuple_list
71+
raise ValueError("All elements must be tuples")
72+
73+
# Convert tuple elements to integers
74+
int_tuple_list = [tuple(int(x) for x in item) for item in tuple_list]
75+
76+
return int_tuple_list
7177
except:
72-
raise argparse.ArgumentTypeError("Input must be a list of tuples")
78+
raise argparse.ArgumentTypeError("Input must be a valid list of tuples with integer elements")
7379

7480

7581
def init(args, quiet = False, allow_auto_split = False, skip_load = False):

0 commit comments

Comments
 (0)