
Commit 15709e0

metascroy authored and malfet committed
initial gguf stuff
1 parent 6ee5a45 commit 15709e0

File tree

3 files changed: +527 -0 lines changed


gguf_util/convert_from_gguf.py

Lines changed: 257 additions & 0 deletions
@@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import argparse

import copy
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Mapping
import logging

import gguf

import torch
import torch.nn as nn

from gguf import GGMLQuantizationType, GGUFValueType, ReaderTensor

wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from model import ModelArgs, Transformer

from typing import Set

logger: logging.Logger = logging.getLogger(__name__)


@dataclass
class AttentionArgs:
    head_count: int
    head_count_kv: int
    layer_norm_rms_epsilon: float


@dataclass
class RopeArgs:
    dimension_count: int | None = None
    freq_base: float | None = None


@dataclass
class GGUFModelArgs:
    arch: str
    embedding_length: int
    block_count: int
    feed_forward_length: int
    vocab_size: int
    attention: AttentionArgs
    rope: RopeArgs


@dataclass
class GGUFWeights:
    tensors: list[ReaderTensor]
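

# Illustration (approximate values, not read from any particular file): a
# LLaMA-2-7B GGUF typically yields roughly embedding_length=4096, block_count=32,
# feed_forward_length=11008, vocab_size=32000, head_count=32, head_count_kv=32,
# and layer_norm_rms_epsilon around 1e-5.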


def _create_pt_model(
    gguf_model_args: GGUFModelArgs,
) -> nn.Module:
    llama_model_args = ModelArgs(
        dim=gguf_model_args.embedding_length,
        n_layer=gguf_model_args.block_count,
        n_head=gguf_model_args.attention.head_count,
        n_local_heads=gguf_model_args.attention.head_count_kv,
        vocab_size=gguf_model_args.vocab_size,
        norm_eps=gguf_model_args.attention.layer_norm_rms_epsilon,
        intermediate_size=gguf_model_args.feed_forward_length,
    )
    pt_model = Transformer(llama_model_args)
    pt_model.eval()
    return pt_model


_name_replacements = [
    ("blk", "layers"),
    ("token_embd", "tok_embeddings"),
    ("attn_q", "attention.wq"),
    ("attn_k", "attention.wk"),
    ("attn_v", "attention.wv"),
    ("attn_output", "attention.wo"),
    ("attn_norm", "attention_norm"),
    ("output_norm.weight", "norm.weight"),
    ("ffn_down", "feed_forward.w2"),
    ("ffn_gate", "feed_forward.w1"),
    ("ffn_up", "feed_forward.w3"),
]


def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str:
    result = copy.deepcopy(gguf_name)
    for gguf_string, replacement in _name_replacements:
        result = result.replace(gguf_string, replacement)
    return result
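
# For example, given the replacement table above, the GGUF tensor name
# "blk.0.attn_q.weight" becomes the module FQN "layers.0.attention.wq.weight",
# and "output_norm.weight" becomes "norm.weight".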


def _fqn_lookup(fqn: str, module: torch.nn.Module) -> Any:
    if fqn == "":
        return module
    atoms = fqn.split(".")
    curr = module
    for a in atoms:
        curr = getattr(curr, a)
    return curr


def _fqn_down(fqn: str, name: str) -> str:
    if fqn == "":
        return name
    return f"{fqn}.{name}"


def _fqn_up(fqn: str) -> str:
    atoms = fqn.split(".")
    if len(atoms) == 1:
        return ""
    return ".".join(atoms[0:-1])


def _fqn_last(fqn: str) -> str:
    atoms = fqn.split(".")
    return atoms[-1]
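
# For example, for fqn = "layers.0.attention.wq.weight":
#   _fqn_up(fqn)   -> "layers.0.attention.wq"   (FQN of the parent module)
#   _fqn_last(fqn) -> "weight"
#   _fqn_lookup("layers.0.attention.wq", model) returns that parent module.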


def _load_weights(
    pt_model: torch.nn.Module, weight_map: dict[str, ReaderTensor]
) -> None:
    # NOTE: to_float, _unpack_two_uint8 and GGMLInt4LinearWeight are not defined
    # in this file; they are presumably provided by the Q4_0 quantization helpers
    # added elsewhere in this commit.

    # state_dict pass: load fp32/fp16 tensors (and the Q4_0 token embedding,
    # dequantized to float) through load_state_dict.
    state_dict = {}
    for fqn, _ in pt_model.state_dict().items():
        if fqn not in weight_map:
            continue
        tensor = weight_map[fqn]

        if tensor.tensor_type in (GGMLQuantizationType.F32, GGMLQuantizationType.F16):
            # reverse the GGUF shape to match the torch layout
            reversed_shape = tensor.shape[::-1]
            new_tensor = tensor.data.reshape(reversed_shape)
            state_dict[fqn] = torch.from_numpy(new_tensor)
        elif (
            tensor.tensor_type == GGMLQuantizationType.Q4_0
            and tensor.name == "token_embd.weight"
        ):
            unpacked = to_float(torch.from_numpy(tensor.data.reshape(-1, 18)))
            state_dict[fqn] = unpacked.reshape(
                pt_model.params.vocab_size, pt_model.params.dim
            )

    # allow partial loading
    pt_model.load_state_dict(state_dict, strict=False)

    # parameter pass: replace Q4_0 linear weights with parameters that keep the
    # packed representation.
    for fqn, param in pt_model.named_parameters():
        if fqn not in weight_map:
            continue
        tensor = weight_map[fqn]

        if tensor.tensor_type == GGMLQuantizationType.Q4_0:
            parent = _fqn_lookup(_fqn_up(fqn), pt_model)
            if isinstance(parent, torch.nn.Linear) and _fqn_last(fqn) == "weight":
                print(fqn, tensor.shape, tensor.data.shape, parent.weight.shape)
                packed = torch.from_numpy(tensor.data).reshape(-1, 18)
                scale = torch.tensor(
                    _unpack_two_uint8(packed[:, :2]), dtype=torch.float16
                )
                parent.weight = torch.nn.Parameter(
                    GGMLInt4LinearWeight(packed, scale, parent.weight.shape)
                )

    # TODO: add some check that every weight was loaded
    # TODO: do we need to add special layers.{id}.attention.mask logic
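
# Note on reshape(-1, 18) above: in the standard GGML Q4_0 layout, each block of
# 32 weights occupies 18 bytes -- a 2-byte fp16 scale followed by 16 bytes holding
# 32 packed 4-bit values, each dequantized as (q - 8) * scale.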


def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
    metadata: dict[str, Any] = {}

    for idx, field in enumerate(reader.fields.values()):
        val = None
        if field.types[:1] == [GGUFValueType.ARRAY]:
            itype = field.types[-1]
            if itype == GGUFValueType.STRING:
                val = [
                    str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data
                ]
            else:
                val = [pv for idx in field.data for pv in field.parts[idx].tolist()]
        elif field.types[0] == GGUFValueType.STRING:
            val = str(bytes(field.parts[-1]), encoding="utf-8")
        else:
            val = field.parts[-1].tolist()[0]

        metadata[field.name] = val

    return metadata
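
# The returned dict is keyed by GGUF field name; _build_model_args below reads
# keys such as "general.architecture", "general.file_type",
# "llama.embedding_length", "llama.block_count", and "tokenizer.ggml.tokens".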


def _build_model_args(metadata: dict[str, Any]) -> GGUFModelArgs:
    arch = metadata["general.architecture"]
    assert arch == "llama", "Only LLaMa models are supported by this converter."

    gguf_ft = metadata["general.file_type"]
    # ALL_F32 or MOSTLY_F16
    assert (
        gguf_ft == 0 or gguf_ft == 1
    ), "Only fp32 or fp16 are supported by this converter."

    return GGUFModelArgs(
        arch=arch,
        embedding_length=metadata[f"{arch}.embedding_length"],
        block_count=metadata[f"{arch}.block_count"],
        feed_forward_length=metadata[f"{arch}.feed_forward_length"],
        vocab_size=len(metadata["tokenizer.ggml.tokens"]),
        attention=AttentionArgs(
            head_count=metadata[f"{arch}.attention.head_count"],
            head_count_kv=metadata[f"{arch}.attention.head_count_kv"],
            layer_norm_rms_epsilon=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
        ),
        rope=RopeArgs(
            freq_base=metadata.get(f"{arch}.rope.freq_base", None),
            dimension_count=metadata.get(f"{arch}.rope.dimension_count", None),
        ),
    )


def load_gguf_file(gguf_file: str) -> tuple[GGUFModelArgs, GGUFWeights]:
    """
    Load a GGUF file and return the model arguments and weights.
    """
    if not Path(gguf_file).is_file():
        raise ValueError(f"Could not find file {gguf_file}")

    reader = gguf.GGUFReader(gguf_file, "r")

    # Step 1: Build GGUFModelArgs
    metadata = _get_metadata(reader)
    model_args = _build_model_args(metadata)

    # Step 2: Build GGUFWeights
    gguf_weights = GGUFWeights(tensors=reader.tensors)

    return (model_args, gguf_weights)


def convert_from_gguf(gguf_file: str) -> torch.nn.Module:
    gguf_model_args, gguf_weights = load_gguf_file(gguf_file)
    assert (
        gguf_model_args.arch == "llama"
    ), "Only LLaMa models are supported by this converter."

    pt_model = _create_pt_model(gguf_model_args)

    # map from fqn in pt_model to gguf tensor
    weight_map = {
        _convert_gguf_tensor_name_to_llama_nn(tensor.name): tensor
        for tensor in gguf_weights.tensors
    }
    _load_weights(pt_model, weight_map)

    return pt_model
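
Example usage (illustrative sketch; the import path and file name are assumptions, and the file must be fp32 or fp16 to pass the file_type check above):

    from gguf_util.convert_from_gguf import convert_from_gguf

    pt_model = convert_from_gguf("path/to/llama-2-7b.f16.gguf")  # hypothetical path
    print(sum(p.numel() for p in pt_model.parameters()))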
