
Commit f48f51d

Merge branch 'master' into xsn/convert_gguf_qwen2vl

2 parents: f7260c2 + 3e168be

5 files changed: 76 additions & 45 deletions

common/arg.cpp

Lines changed: 17 additions & 0 deletions
@@ -1948,6 +1948,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.grammar = json_schema_to_grammar(json::parse(value));
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"-jf", "--json-schema-file"}, "FILE",
+        "File containing a JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead",
+        [](common_params & params, const std::string & value) {
+            std::ifstream file(value);
+            if (!file) {
+                throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
+            }
+            std::string schema;
+            std::copy(
+                std::istreambuf_iterator<char>(file),
+                std::istreambuf_iterator<char>(),
+                std::back_inserter(schema)
+            );
+            params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
+        }
+    ).set_sparam());
     add_opt(common_arg(
         {"--pooling"}, "{none,mean,cls,last,rank}",
         "pooling type for embeddings, use model default if unspecified",

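The new -jf/--json-schema-file option reads the whole file into a string, parses it as JSON, and hands the result to json_schema_to_grammar, the same path the existing --json-schema option takes in the context lines above, so a missing or unreadable file fails immediately. Below is a rough Python sketch of the read-and-validate step only; the function and file names are illustrative, and the actual GBNF conversion stays in the C++ json_schema_to_grammar helper.

import json
from pathlib import Path
from typing import Any

def load_json_schema(path: str) -> dict[str, Any]:
    # mirror of the new handler: fail fast if the file cannot be opened
    file = Path(path)
    if not file.is_file():
        raise RuntimeError(f"error: failed to open file '{path}'")
    # json.loads raises on malformed input, just like json::parse in the C++ code
    return json.loads(file.read_text(encoding="utf-8"))

if __name__ == "__main__":
    Path("schema.json").write_text("{}", encoding="utf-8")  # '{}' accepts any JSON object
    print(load_json_schema("schema.json"))
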
convert_hf_to_gguf.py

Lines changed: 57 additions & 41 deletions
@@ -16,6 +16,7 @@
 from hashlib import sha256
 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
 from itertools import chain
+from transformers import AutoConfig
 
 import math
 import numpy as np
@@ -66,8 +67,6 @@ class ModelBase:
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
-    block_count: int
-    tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
@@ -78,6 +77,10 @@ class ModelBase:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
+    # subclasses should initialize this!
+    block_count: int
+    tensor_map: gguf.TensorNameMap
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
@@ -113,8 +116,6 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
         if not self.is_safetensors:
             self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
-        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
@@ -417,15 +418,13 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]
 
     @staticmethod
     def load_hparams(dir_model: Path):
-        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            hparams = json.load(f)
-        architectures = hparams.get("architectures")
-        if "text_config" in hparams:
-            hparams = {**hparams, **hparams["text_config"]}
-        if architectures is not None:
-            # preserve "architectures" from root level config
-            hparams["architectures"] = architectures
-        return hparams
+        try:
+            return AutoConfig.from_pretrained(dir_model).to_dict()
+        except Exception as e:
+            logger.warning(f"Failed to load model config from {dir_model}: {e}")
+            logger.warning("Trying to load config.json instead")
+        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+            return json.load(f)
 
     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
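load_hparams now asks transformers' AutoConfig for the config first and only falls back to reading config.json by hand when that fails; the root-level text_config merging moves into TextModel (next hunk). A standalone sketch of the same pattern, assuming transformers is installed (the script already imports it above):

import json
import logging
from pathlib import Path
from typing import Any

from transformers import AutoConfig

logger = logging.getLogger(__name__)

def load_hparams(dir_model: Path) -> dict[str, Any]:
    try:
        # PretrainedConfig.to_dict() returns the config.json keys plus library defaults
        return AutoConfig.from_pretrained(dir_model).to_dict()
    except Exception as e:
        logger.warning("Failed to load model config from %s: %s", dir_model, e)
        logger.warning("Trying to load config.json instead")
    with open(dir_model / "config.json", "r", encoding="utf-8") as f:
        return json.load(f)

Point it at any local Hugging Face model directory to exercise the AutoConfig path, or at a directory whose config AutoConfig cannot resolve to exercise the raw config.json fallback.
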
@@ -454,6 +453,23 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type
 
 
 class TextModel(ModelBase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if "text_config" in self.hparams:
+            # move the text_config to the root level
+            self.hparams = {**self.hparams, **self.hparams["text_config"]}
+
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    @classmethod
+    def __init_subclass__(cls):
+        # can't use an abstract property, because overriding it without type errors
+        # would require using decorated functions instead of simply defining the property
+        if "model_arch" not in cls.__dict__:
+            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
+
     def set_vocab(self):
         self._set_vocab_gpt2()
 
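TextModel now validates its subclasses at class-creation time: __init_subclass__ rejects any subclass whose own body does not define model_arch, which is also why NomicBertModel gains an explicit model_arch further down. A minimal, self-contained illustration of the guard, using placeholder names rather than code from the repository:

class Base:
    # subclasses must define this in their own class body
    arch: str

    @classmethod
    def __init_subclass__(cls):
        # cls.__dict__ only contains attributes defined directly on the subclass,
        # so inheriting 'arch' from a parent class is not enough
        if "arch" not in cls.__dict__:
            raise TypeError(f"Missing property 'arch' for {cls.__name__!r}")

class Good(Base):
    arch = "llama"

try:
    class Bad(Base):  # no 'arch' here, rejected while the class is being created
        pass
except TypeError as e:
    print(e)  # Missing property 'arch' for 'Bad'
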
@@ -1070,9 +1086,9 @@ def __init__(self, *args, **kwargs):
         if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
 
-        # small hack to correct the number of layers
-        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
-        self.n_embd_text = self.find_hparam(["hidden_size", "n_embd"])
+        # get n_embd of the text model
+        text_config = {**self.hparams, **self.hparams["text_config"]}
+        self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         assert self.n_embd_text > 0, "n_embd not found in hparams"
 
         if "vision_config" not in self.hparams:
@@ -1081,6 +1097,9 @@ def __init__(self, *args, **kwargs):
         self.global_config = self.hparams
         self.hparams = self.hparams["vision_config"]
 
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
+        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
+
         # load preprocessor config
         with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
             self.preprocessor_config = json.load(f)
@@ -1098,7 +1117,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
         self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
         self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
-        self.gguf_writer.add_vision_block_count(self.find_hparam(["num_hidden_layers"]))
+        self.gguf_writer.add_vision_block_count(self.block_count)
         self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
 
         # preprocessor config
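VisionModel no longer hard-codes 128 layers for the tensor name map: it reads the text embedding width from the root config merged with text_config, then, after switching self.hparams to vision_config, resolves the real vision block count (including the "depth" spelling some vision configs use). A toy sketch of those two lookups, with a made-up config dict and a simplified stand-in for ModelBase.find_hparam:

from typing import Any

def find_hparam(hparams: dict[str, Any], keys: list[str]) -> Any:
    # return the value of the first key that is present, like ModelBase.find_hparam
    for key in keys:
        if key in hparams:
            return hparams[key]
    raise KeyError(f"none of {keys} found in hparams")

# fabricated multimodal config, for illustration only
config = {
    "hidden_size": 4096,  # text width lives at the root here
    "text_config": {"num_hidden_layers": 32},
    "vision_config": {"depth": 24, "hidden_size": 1280},
}

# n_embd of the text model: text_config entries override the root
text_config = {**config, **config.get("text_config", {})}
n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))

# vision block count, accepting the "depth" key as a fallback name
vision_block_count = find_hparam(
    config["vision_config"],
    ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"],
)

print(n_embd_text, vision_block_count)  # 4096 24
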
@@ -1719,23 +1738,12 @@ def prepare_tensors(self):
     "LlamaForCausalLM",
     "MistralForCausalLM",
     "MixtralForCausalLM",
-    "Idefics3ForConditionalGeneration",
-    "SmolVLMForConditionalGeneration",
+    "VLlama3ForCausalLM",
     "LlavaForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-        # fix for Pixtral, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
-                and self.hparams.get("model_type") == "mistral":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-
     def set_vocab(self):
         try:
             self._set_vocab_sentencepiece()
@@ -1898,11 +1906,7 @@ class LlavaVisionModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.hparams["model_type"] == "pixtral":
-            # fix missing config.json values
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
-            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
-            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
+            # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
             self.img_break_tok_id = 12 # see tokenizer_config.json
         else:
@@ -1913,7 +1917,6 @@ def set_gguf_parameters(self):
         hparams = self.hparams
         if hparams["model_type"] == "pixtral":
             self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
-            # default values below are taken from HF tranformers code
             self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
             self.gguf_writer.add_vision_use_silu(True)
 
@@ -1944,13 +1947,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class SmolVLMModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing some keys in config.json
-        # default values are taken from transformers code
         if self.hparams["model_type"] == "smolvlm_vision":
+            # fix for SmolVLM2, missing some keys in config.json
+            # default values are taken from transformers code
            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152)
            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 12)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
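With AutoConfig supplying most hyperparameters, the SmolVLM2 and Pixtral special cases shrink to the values that genuinely are not in config.json, filled in with the usual get-with-default idiom. A small generalization of that idiom; the helper name and toy config are invented for this sketch, and the defaults are the ones from the hunk above:

from typing import Any

def apply_defaults(hparams: dict[str, Any], defaults: dict[str, Any]) -> None:
    # keep whatever config.json provided, fill in only what is missing
    for key, value in defaults.items():
        hparams[key] = hparams.get(key, value)

hparams = {"model_type": "smolvlm_vision"}  # toy config missing the vision sizes
if hparams["model_type"] == "smolvlm_vision":
    apply_defaults(hparams, {
        "hidden_size": 1152,
        "num_attention_heads": 16,
        "intermediate_size": 3072,
    })

print(hparams["hidden_size"], hparams["num_attention_heads"])  # 1152 16
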
@@ -3581,6 +3583,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
 @ModelBase.register("NomicBertModel")
 class NomicBertModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
         hparams = kwargs.pop("hparams", None)
         if hparams is None:
@@ -5925,6 +5929,19 @@ def split_str_to_n_bytes(split_str: str) -> int:
     return n
 
 
+def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
+    hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+    text_config = hparams.get("text_config", {})
+    vision_config = hparams.get("vision_config", {})
+    arch = hparams["architectures"][0]
+    # if "architectures" is found in the sub-config, use that instead
+    if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
+        arch = text_config["architectures"][0]
+    elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
+        arch = vision_config["architectures"][0]
+    return arch
+
+
 def main() -> None:
     args = parse_args()
 
@@ -5977,16 +5994,15 @@ def main() -> None:
 
     logger.info(f"Loading model: {dir_model.name}")
 
-    hparams = ModelBase.load_hparams(dir_model)
-
     if args.mmproj:
         if "mmproj" not in fname_out.name:
             fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
 
     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
-        model_architecture = hparams["architectures"][0]
         model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
+        model_architecture = get_model_architecture(dir_model, model_type)
+        logger.info(f"Model architecture: {model_architecture}")
         try:
             model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
         except NotImplementedError:
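main() now resolves the architecture through get_model_architecture(), so a multimodal checkpoint can map to different converter classes depending on whether --mmproj was passed. A toy illustration of that selection logic with a fabricated config; ModelType here is a stand-in for the enum the script already defines, and the hparams dict replaces the load from disk:

from enum import Enum, auto
from typing import Any

class ModelType(Enum):  # stand-in for convert_hf_to_gguf.py's ModelType
    TEXT = auto()
    VISION = auto()

def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
    text_config = hparams.get("text_config", {})
    vision_config = hparams.get("vision_config", {})
    arch = hparams["architectures"][0]
    # prefer the sub-config's "architectures" when it is present
    if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
        arch = text_config["architectures"][0]
    elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
        arch = vision_config["architectures"][0]
    return arch

# fabricated multimodal config for illustration
hparams = {
    "architectures": ["LlavaForConditionalGeneration"],
    "text_config": {"architectures": ["MistralForCausalLM"]},
    "vision_config": {},
}
print(get_model_architecture(hparams, ModelType.TEXT))    # MistralForCausalLM
print(get_model_architecture(hparams, ModelType.VISION))  # LlavaForConditionalGeneration
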

examples/llava/clip-impl.h

Lines changed: 0 additions & 2 deletions
@@ -2,8 +2,6 @@
 #include "gguf.h"
 #include "clip.h"
 
-#include "clip.h"
-
 #include <climits>
 #include <cstdarg>
 #include <string>

ggml/src/ggml-cpu/simd-mappings.h

Lines changed: 1 addition & 1 deletion
@@ -341,7 +341,7 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
 #define GGML_F32_EPR 4
 
 #define GGML_F32x4 vector float
-#define GGML_F32x4_ZERO 0.0f
+#define GGML_F32x4_ZERO {0.0f}
 #define GGML_F32x4_SET1 vec_splats
 #define GGML_F32x4_LOAD(p) vec_xl(0, p)
 #define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 1 addition & 1 deletion
@@ -482,7 +482,7 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo
     const uint ib8 = (idx & 0x18) >> 3; // 0..3
     const uint iqs = 8 * ib32 + ib8;
 
-    const uint8_t qs = bl.block.qs[iqs];
+    const uint qs = bl.block.qs[iqs];
     const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
 
     const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28));
