Commit 2b077c9

Committed by: fmz

Add JAIS model(s)

1 parent f3f6542 commit 2b077c9

File tree: 8 files changed, +279 -21 lines

8 files changed

+279
-21
lines changed

convert-hf-to-gguf-update.py
Lines changed: 1 addition & 0 deletions

@@ -85,6 +85,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "smaug-bpe",    "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", },
     {"name": "poro-chat",    "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
     {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
+    {"name": "jais",         "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
 ]
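Each entry in this table tells the update script to download that repo's tokenizer, fingerprint it, and regenerate the hash checks in convert-hf-to-gguf.py. A minimal sketch of the fingerprinting idea, with a hypothetical probe string standing in for the script's fixed test text:

    from hashlib import sha256
    from transformers import AutoTokenizer

    chktxt = "Hello, world! 123"  # hypothetical; the real script uses a long edge-case string

    tokenizer = AutoTokenizer.from_pretrained("core42/jais-13b")
    chktok = tokenizer.encode(chktxt)

    # Hash the token ids: any pre-tokenizer difference changes the chkhsh,
    # which is what get_vocab_base_pre() matches against below.
    chkhsh = sha256(str(chktok).encode()).hexdigest()
    print(chkhsh)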

convert-hf-to-gguf.py
Lines changed: 76 additions & 9 deletions

@@ -427,9 +427,6 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         # NOTE: if you get an error here, you need to update the convert-hf-to-gguf-update.py script
         # or pull the latest version of the model from Huggingface
         # don't edit the hashes manually!
-        if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
-            # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
-            res = "llama-bpe"
         if chkhsh == "049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754":
             # ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base
             res = "deepseek-llm"
@@ -457,18 +454,12 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "6221ad2852e85ce96f791f476e0b390cf9b474c9e3d1362f53a24a06dc8220ff":
             # ref: https://huggingface.co/smallcloudai/Refact-1_6-base
             res = "refact"
-        if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8":
-            # ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01
-            res = "command-r"
         if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
             # ref: https://huggingface.co/Qwen/Qwen1.5-7B
             res = "qwen2"
         if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
             # ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf
             res = "olmo"
-        if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
-            # ref: https://huggingface.co/databricks/dbrx-base
-            res = "dbrx"
         if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
             # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en
             res = "jina-v2-en"
@@ -487,6 +478,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
             # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
             res = "jina-v2-code"
+        if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901":
+            # ref: https://huggingface.co/core42/jais-13b
+            res = "jais"

         if res is None:
             logger.warning("\n")
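The string assigned to res is recorded in the GGUF metadata, which is how llama.cpp later selects LLAMA_VOCAB_PRE_TYPE_JAIS at load time. A rough sketch of that hand-off, assuming the gguf-py writer API used elsewhere in this converter:

    import gguf

    # Sketch: record the pre-tokenizer choice in the output file's metadata.
    writer = gguf.GGUFWriter("jais-13b.gguf", gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.JAIS])
    writer.add_tokenizer_model("gpt2")  # JAIS uses a GPT-2-style BPE tokenizer
    writer.add_tokenizer_pre("jais")    # written under the tokenizer.ggml.pre key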
@@ -2774,6 +2768,79 @@ def write_tensors(self):
         if len(experts) > 0:
             raise ValueError(f"Unprocessed experts: {experts}")

+
+@Model.register("JAISLMHeadModel")
+class JaisModel(Model):
+    model_arch = gguf.MODEL_ARCH.JAIS
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # SwiGLU activation
+        assert self.hparams["activation_function"] == "swiglu"
+        # ALiBi position embedding
+        assert self.hparams["position_embedding_type"] == "alibi"
+
+        # Embeddings scale
+        self.embeddings_scale = 1.0
+        # note: for some JAIS flavors, the output projection is tied to (same as) wte in the original model
+        self.output_is_wte = False
+        if 'mup_embeddings_scale' in self.hparams:
+            self.output_is_wte = True  # Hack (?)
+            self.embeddings_scale = self.hparams['mup_embeddings_scale']
+        elif 'embeddings_scale' in self.hparams:
+            self.embeddings_scale = self.hparams['embeddings_scale']
+        else:
+            assert False
+
+        self.width_scale = 1.0
+        if 'mup_output_alpha' in self.hparams:
+            assert 'mup_width_scale' in self.hparams
+            self.width_scale = self.hparams['mup_output_alpha'] * self.hparams['mup_width_scale']
+        elif 'width_scale' in self.hparams:
+            self.width_scale = self.hparams['width_scale']
+        else:
+            assert False
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"])
+        self.gguf_writer.add_head_count(self.hparams["n_head"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        tensors: list[tuple[str, Tensor]] = []
+
+        # we don't need these
+        if name.endswith((".attn.bias", "relative_pe.slopes")):
+            return tensors
+
+        # GPT-2-style Conv1D weights are stored transposed relative to Linear
+        if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_fc2.weight")):
+            data_torch = data_torch.transpose(1, 0)
+
+        new_name = self.map_tensor_name(name)
+
+        if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
+            tensors.append((new_name, data_torch * self.embeddings_scale))
+            if self.output_is_wte:
+                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch * self.width_scale))
+        elif new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
+            assert not self.output_is_wte
+            tensors.append((new_name, data_torch * self.width_scale))
+        else:
+            tensors.append((new_name, data_torch))
+
+        return tensors
+
+
 @Model.register("T5ForConditionalGeneration")
 @Model.register("T5WithLMHeadModel")

examples/main/main.cpp
Lines changed: 0 additions & 1 deletion

@@ -733,7 +733,6 @@ int main(int argc, char ** argv) {

             // Console/Stream Output
             fprintf(stdout, "%s", token_str.c_str());
-
             // Record Displayed Tokens To Log
             // Note: Generated tokens are created one by one hence this check
             if (embd.size() > 1) {

ggml/src/ggml.c
Lines changed: 1 addition & 1 deletion

@@ -13516,13 +13516,13 @@ static void ggml_compute_forward_soft_max_f32(
        } else {
            for (int i = 0; i < nc; ++i) {
                wp[i] += slope*mp_f32[i];
+
            }
        }

#ifndef NDEBUG
        for (int i = 0; i < nc; ++i) {
-           //printf("p[%d] = %f\n", i, p[i]);
            assert(!isnan(wp[i]));
        }
#endif
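For context, wp[i] += slope*mp_f32[i] is where ALiBi-style position bias (used by JAIS) enters the softmax: the mask row mp_f32 is scaled by a per-head slope and added to the logits before normalization. A small numpy sketch of the same computation, with made-up values:

    import numpy as np

    def softmax(x: np.ndarray) -> np.ndarray:
        e = np.exp(x - x.max())  # subtract the max for stability, as ggml does
        return e / e.sum()

    slope = 0.0625                          # hypothetical per-head ALiBi slope
    wp = np.array([1.2, 0.7, 0.3, 0.9])     # one row of KQ logits
    mp = np.array([0.0, -1.0, -2.0, -3.0])  # bias input: distance to each key

    print(softmax(wp + slope * mp))         # nearer keys keep more weight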

gguf-py/gguf/constants.py
Lines changed: 14 additions & 0 deletions

@@ -160,6 +160,7 @@ class MODEL_ARCH(IntEnum):
     DEEPSEEK2 = auto()
     BITNET    = auto()
     T5        = auto()
+    JAIS      = auto()


 class MODEL_TENSOR(IntEnum):
@@ -280,6 +281,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DEEPSEEK2: "deepseek2",
     MODEL_ARCH.BITNET:    "bitnet",
     MODEL_ARCH.T5:        "t5",
+    MODEL_ARCH.JAIS:      "jais",
 }

 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -928,6 +930,18 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ENC_FFN_UP,
         MODEL_TENSOR.ENC_OUTPUT_NORM,
     ],
+    MODEL_ARCH.JAIS: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
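This block only declares which tensors a JAIS GGUF may contain; the on-disk names come from TENSOR_NAMES earlier in the same file. A quick sketch of the resulting names, assuming the usual gguf-py templates such as "blk.{bid}.ffn_gate":

    import gguf

    for tensor in (gguf.MODEL_TENSOR.TOKEN_EMBD,
                   gguf.MODEL_TENSOR.ATTN_QKV,
                   gguf.MODEL_TENSOR.FFN_GATE):
        # TENSOR_NAMES maps each enum to a template like "blk.{bid}.ffn_gate"
        print(gguf.TENSOR_NAMES[tensor].format(bid=0) + ".weight")
    # expected: token_embd.weight, blk.0.attn_qkv.weight, blk.0.ffn_gate.weight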

gguf-py/gguf/tensor_mapping.py
Lines changed: 10 additions & 9 deletions

@@ -10,7 +10,7 @@ class TensorNameMap:
     # Token embeddings
     MODEL_TENSOR.TOKEN_EMBD: (
         "gpt_neox.embed_in",            # gptneox
-        "transformer.wte",              # gpt2 gpt-j mpt refact qwen dbrx
+        "transformer.wte",              # gpt2 gpt-j mpt refact qwen dbrx jais
         "transformer.word_embeddings",  # falcon
         "word_embeddings",              # bloom
         "model.embed_tokens",           # llama-hf
@@ -49,7 +49,7 @@ class TensorNameMap:
     # Output
     MODEL_TENSOR.OUTPUT: (
         "embed_out",                 # gptneox
-        "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx
+        "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais
         "output",                    # llama-pth bloom internlm2
         "word_embeddings_for_head",  # persimmon
         "lm_head.linear",            # phi2
@@ -58,7 +58,7 @@ class TensorNameMap:
     # Output norm
     MODEL_TENSOR.OUTPUT_NORM: (
         "gpt_neox.final_layer_norm",  # gptneox
-        "transformer.ln_f",           # gpt2 gpt-j falcon
+        "transformer.ln_f",           # gpt2 gpt-j falcon jais
         "model.norm",                 # llama-hf baichuan internlm2
         "norm",                       # llama-pth
         "transformer.norm_f",         # mpt dbrx
@@ -81,7 +81,7 @@ class TensorNameMap:
     # Attention norm
     MODEL_TENSOR.ATTN_NORM: (
         "gpt_neox.layers.{bid}.input_layernorm",  # gptneox
-        "transformer.h.{bid}.ln_1",               # gpt2 gpt-j refact qwen
+        "transformer.h.{bid}.ln_1",               # gpt2 gpt-j refact qwen jais
         "transformer.blocks.{bid}.norm_1",        # mpt
         "transformer.h.{bid}.input_layernorm",    # falcon7b
         "h.{bid}.input_layernorm",                # bloom
@@ -109,7 +109,7 @@ class TensorNameMap:
     # Attention query-key-value
     MODEL_TENSOR.ATTN_QKV: (
         "gpt_neox.layers.{bid}.attention.query_key_value",     # gptneox
-        "transformer.h.{bid}.attn.c_attn",                     # gpt2 qwen
+        "transformer.h.{bid}.attn.c_attn",                     # gpt2 qwen jais
         "transformer.blocks.{bid}.attn.Wqkv",                  # mpt
         "transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv",   # dbrx
         "transformer.h.{bid}.self_attention.query_key_value",  # falcon
@@ -160,7 +160,7 @@ class TensorNameMap:
     # Attention output
     MODEL_TENSOR.ATTN_OUT: (
         "gpt_neox.layers.{bid}.attention.dense",     # gptneox
-        "transformer.h.{bid}.attn.c_proj",           # gpt2 refact qwen
+        "transformer.h.{bid}.attn.c_proj",           # gpt2 refact qwen jais
         "transformer.blocks.{bid}.attn.out_proj",    # mpt
         "transformer.h.{bid}.self_attention.dense",  # falcon
         "h.{bid}.self_attention.dense",              # bloom
@@ -198,7 +198,7 @@ class TensorNameMap:
     # Feed-forward norm
     MODEL_TENSOR.FFN_NORM: (
         "gpt_neox.layers.{bid}.post_attention_layernorm",  # gptneox
-        "transformer.h.{bid}.ln_2",                        # gpt2 refact qwen
+        "transformer.h.{bid}.ln_2",                        # gpt2 refact qwen jais
         "h.{bid}.post_attention_layernorm",                # bloom
         "transformer.blocks.{bid}.norm_2",                 # mpt
         "model.layers.{bid}.post_attention_layernorm",     # llama-hf
@@ -225,7 +225,7 @@ class TensorNameMap:
     # Feed-forward up
     MODEL_TENSOR.FFN_UP: (
         "gpt_neox.layers.{bid}.mlp.dense_h_to_4h",  # gptneox
-        "transformer.h.{bid}.mlp.c_fc",             # gpt2
+        "transformer.h.{bid}.mlp.c_fc",             # gpt2 jais
         "transformer.blocks.{bid}.ffn.up_proj",     # mpt
         "transformer.h.{bid}.mlp.dense_h_to_4h",    # falcon
         "h.{bid}.mlp.dense_h_to_4h",                # bloom
@@ -271,6 +271,7 @@ class TensorNameMap:
         "model.layers.{bid}.mlp.gate_proj",         # llama-hf refact
         "layers.{bid}.feed_forward.w1",             # llama-pth
         "transformer.h.{bid}.mlp.w2",               # qwen
+        "transformer.h.{bid}.mlp.c_fc2",            # jais
         "model.layers.layers.{bid}.mlp.gate_proj",  # plamo
         "model.layers.{bid}.feed_forward.w1",       # internlm2
         "encoder.layers.{bid}.mlp.fc12",            # nomic-bert
@@ -294,7 +295,7 @@ class TensorNameMap:
     # Feed-forward down
     MODEL_TENSOR.FFN_DOWN: (
         "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",  # gptneox
-        "transformer.h.{bid}.mlp.c_proj",           # gpt2 refact qwen
+        "transformer.h.{bid}.mlp.c_proj",           # gpt2 refact qwen jais
         "transformer.blocks.{bid}.ffn.down_proj",   # mpt
         "transformer.h.{bid}.mlp.dense_4h_to_h",    # falcon
         "h.{bid}.mlp.dense_4h_to_h",                # bloom

include/llama.h
Lines changed: 1 addition & 0 deletions

@@ -88,6 +88,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_DBRX  = 13,
         LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
         LLAMA_VOCAB_PRE_TYPE_PORO  = 15,
+        LLAMA_VOCAB_PRE_TYPE_JAIS  = 16,
     };

     // note: these values should be synchronized with ggml_rope
