Commit c679e0c

Authored by mscheong01, with co-authors ggerganov and compilade
llama : add EXAONE model support (#9025)
* add exaone model support
* add chat template
* fix whitespace
  Co-authored-by: Georgi Gerganov <[email protected]>
* add ftype
* add exaone pre-tokenizer in `llama-vocab.cpp`
  Co-Authored-By: compilade <[email protected]>
* fix lint
  Co-Authored-By: compilade <[email protected]>
* add `EXAONE` to supported models in `README.md`
* fix space
  Co-authored-by: compilade <[email protected]>

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: compilade <[email protected]>
Co-authored-by: compilade <[email protected]>
1 parent fb487bb commit c679e0c

8 files changed: +320, -7 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ Typically finetunes of the base models below are supported as well.
 - [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca)
 - [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
 - [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
+- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
 
 (instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))

convert_hf_to_gguf.py

Lines changed: 74 additions & 0 deletions
@@ -596,6 +596,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "bc01ce58980e1db43859146dc51b1758b3b88729b217a74792e9f8d43e479d21":
             # ref: https://huggingface.co/TurkuNLP/gpt3-finnish-small
             res = "gpt3-finnish"
+        if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae":
+            # ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct
+            res = "exaone"
 
         if res is None:
             logger.warning("\n")
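
Each `chkhsh` above fingerprints a model's pre-tokenizer so the converter can select the matching pre-tokenizer preset. A minimal sketch of how such a fingerprint can be reproduced, assuming (as the surrounding script suggests) that it is the SHA-256 of the stringified token IDs an AutoTokenizer emits for a fixed check string; the check string below is a placeholder, not the one the script actually uses:

# Minimal sketch (assumption: the script hashes str() of the token IDs produced
# for a fixed check string). CHECK_TEXT is a placeholder, NOT the real check
# string, so the printed hash will not match 4e2b24cc... as-is.
from hashlib import sha256

from transformers import AutoTokenizer  # downloading the repo may require accepting its license

CHECK_TEXT = "placeholder text with digits 1234, punctuation!? and   mixed   spacing"

tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct")
chkhsh = sha256(str(tokenizer.encode(CHECK_TEXT)).encode()).hexdigest()
print(chkhsh)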
@@ -3781,6 +3784,77 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         return [(self.map_tensor_name(name), data_torch)]
 
+
+@Model.register("ExaoneForCausalLM")
+class ExaoneModel(Model):
+    model_arch = gguf.MODEL_ARCH.EXAONE
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+
+        assert(hparams["activation_function"] == "silu")
+
+        max_position_embeddings = hparams["max_position_embeddings"]
+        embed_dim = hparams["hidden_size"]
+        num_heads = hparams["num_attention_heads"]
+        num_kv_heads = hparams.get("num_key_value_heads", num_heads)
+        layer_norm_eps = hparams["layer_norm_epsilon"]
+        intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
+        num_layers = hparams["num_layers"]
+        # ignore for now as EXAONE-3.0-7.8B-Instruct attention_dropout is 0.0
+        # attention_dropout_rate = hparams["attention_dropout"]
+        # ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
+        # embed_dropout_rate = hparams["embed_dropout"]
+        self.gguf_writer.add_embedding_length(embed_dim)
+        self.gguf_writer.add_head_count(num_heads)
+        self.gguf_writer.add_head_count_kv(num_kv_heads)
+        self.gguf_writer.add_context_length(max_position_embeddings)
+        self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_block_count(num_layers)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        if (rope_theta := self.hparams.get("rope_theta")) is not None:
+            self.gguf_writer.add_rope_freq_base(rope_theta)
+        rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
+        rotary_factor = rotary_factor if rotary_factor is not None else 1.0
+        self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
+        if hparams.get("rope_scaling") is not None and "factor" in hparams["rope_scaling"]:
+            if hparams["rope_scaling"].get("type") == "linear":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
+
+    def prepare_tensors(self):
+        if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
+            if rope_scaling.get("rope_type", '').lower() == "llama3":
+                base = self.hparams.get("rope_theta", 10000.0)
+                dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+                freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+                factor = rope_scaling.get("factor", 8.0)
+                low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+                high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+                old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
+
+                low_freq_wavelen = old_context_len / low_freq_factor
+                high_freq_wavelen = old_context_len / high_freq_factor
+                assert low_freq_wavelen != high_freq_wavelen
+
+                rope_factors = []
+                for freq in freqs:
+                    wavelen = 2 * math.pi / freq
+                    if wavelen < high_freq_wavelen:
+                        rope_factors.append(1)
+                    elif wavelen > low_freq_wavelen:
+                        rope_factors.append(factor)
+                    else:
+                        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+                        rope_factors.append(1 / ((1 - smooth) / factor + smooth))
+
+                self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
+
+        super().prepare_tensors()
+
 ###### CONVERSION LOGIC ######
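`ExaoneModel.prepare_tensors` above precomputes one scaling factor per RoPE frequency whenever the checkpoint declares llama3-style `rope_scaling`, and writes them out as the ROPE_FREQS tensor. Below is a standalone sketch of that computation using the same fallback defaults as the diff (`factor=8.0`, `low_freq_factor=1.0`, `high_freq_factor=4.0`, original context length 8192); the head dimension of 128 is only an illustrative assumption, not a value taken from the EXAONE config:

# Standalone sketch of the llama3-style RoPE factor computation shown above.
# head_dim=128 is an illustrative assumption for this example only.
import math

import torch


def llama3_rope_factors(dim: int, base: float = 10000.0, factor: float = 8.0,
                        low_freq_factor: float = 1.0, high_freq_factor: float = 4.0,
                        old_context_len: int = 8192) -> list[float]:
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    low_freq_wavelen = old_context_len / low_freq_factor    # 8192.0 with defaults
    high_freq_wavelen = old_context_len / high_freq_factor  # 2048.0 with defaults
    factors = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            factors.append(1.0)        # high-frequency dims stay unscaled
        elif wavelen > low_freq_wavelen:
            factors.append(factor)     # low-frequency dims get the full factor
        else:
            # smooth interpolation between the two regimes
            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            factors.append(1 / ((1 - smooth) / factor + smooth))
    return factors


factors = llama3_rope_factors(dim=128)
print(factors[0], factors[-1])  # ~1.0 for the fastest dimension, ~8.0 for the slowest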

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", },
     {'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", },
     {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", },
+    {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
 ]

gguf-py/gguf/constants.py

Lines changed: 18 additions & 0 deletions
@@ -220,6 +220,7 @@ class MODEL_ARCH(IntEnum):
     T5ENCODER = auto()
     JAIS = auto()
     NEMOTRON = auto()
+    EXAONE = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -349,6 +350,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.T5ENCODER: "t5encoder",
     MODEL_ARCH.JAIS: "jais",
     MODEL_ARCH.NEMOTRON: "nemotron",
+    MODEL_ARCH.EXAONE: "exaone",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -1082,6 +1084,22 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.EXAONE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
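The new `MODEL_ARCH.EXAONE` entry and its tensor list determine which GGUF tensor names an EXAONE conversion may contain. A small sketch of how they could be inspected from `gguf-py`, assuming a build that already includes this change and the `MODEL_TENSORS`/`TENSOR_NAMES` tables visible in the context lines above:

# Sketch: print the GGUF name template for every tensor the EXAONE arch declares.
# Requires a gguf-py version that already contains MODEL_ARCH.EXAONE.
from gguf.constants import MODEL_ARCH, MODEL_TENSORS, TENSOR_NAMES

for tensor in MODEL_TENSORS[MODEL_ARCH.EXAONE]:
    print(f"{tensor.name:14} -> {TENSOR_NAMES[tensor]}")
# expected output includes e.g. TOKEN_EMBD -> token_embd and ATTN_Q -> blk.{bid}.attn_q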

gguf-py/gguf/tensor_mapping.py

Lines changed: 13 additions & 6 deletions
@@ -10,7 +10,7 @@ class TensorNameMap:
         # Token embeddings
         MODEL_TENSOR.TOKEN_EMBD: (
             "gpt_neox.embed_in", # gptneox
-            "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais
+            "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
             "transformer.word_embeddings", # falcon
             "word_embeddings", # bloom
             "model.embed_tokens", # llama-hf nemotron
@@ -52,7 +52,7 @@
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out", # gptneox
-            "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron
+            "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone
             "output", # llama-pth bloom internlm2
             "word_embeddings_for_head", # persimmon
             "lm_head.linear", # phi2
@@ -62,7 +62,7 @@
         # Output norm
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm", # gptneox
-            "transformer.ln_f", # gpt2 gpt-j falcon jais
+            "transformer.ln_f", # gpt2 gpt-j falcon jais exaone
             "model.norm", # llama-hf baichuan internlm2
             "norm", # llama-pth
             "transformer.norm_f", # mpt dbrx
@@ -89,7 +89,7 @@
         # Attention norm
         MODEL_TENSOR.ATTN_NORM: (
             "gpt_neox.layers.{bid}.input_layernorm", # gptneox
-            "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais
+            "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone
             "transformer.blocks.{bid}.norm_1", # mpt
             "transformer.h.{bid}.input_layernorm", # falcon7b
             "h.{bid}.input_layernorm", # bloom
@@ -143,6 +143,7 @@
             "model.layers.layers.{bid}.self_attn.q_proj", # plamo
             "model.layers.{bid}.attention.wq", # internlm2
             "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
+            "transformer.h.{bid}.attn.attention.q_proj", # exaone
         ),
 
         # Attention key
@@ -155,6 +156,7 @@
             "model.layers.layers.{bid}.self_attn.k_proj", # plamo
             "model.layers.{bid}.attention.wk", # internlm2
             "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
+            "transformer.h.{bid}.attn.attention.k_proj", # exaone
         ),
 
         # Attention value
@@ -166,7 +168,8 @@
             "transformer.h.{bid}.attn.v", # refact
             "model.layers.layers.{bid}.self_attn.v_proj", # plamo
             "model.layers.{bid}.attention.wv", # internlm2
-            "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
+            "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
+            "transformer.h.{bid}.attn.attention.v_proj", # exaone
         ),
 
         # Attention output
@@ -191,6 +194,7 @@
             "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
             "encoder.layers.{bid}.self_attention.dense", # chatglm
             "transformer.layers.{bid}.attn.out_proj", # openelm
+            "transformer.h.{bid}.attn.attention.out_proj", # exaone
         ),
 
         # Attention output norm
@@ -216,7 +220,7 @@
         # Feed-forward norm
         MODEL_TENSOR.FFN_NORM: (
             "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
-            "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais
+            "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
             "h.{bid}.post_attention_layernorm", # bloom
             "transformer.blocks.{bid}.norm_2", # mpt
             "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron
@@ -278,6 +282,7 @@
             "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
             "model.layers.{bid}.residual_mlp.w3", # arctic
             "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
+            "transformer.h.{bid}.mlp.c_fc_1", # exaone
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -309,6 +314,7 @@
             "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
             "transformer.h.{bid}.mlp.linear_1", # refact
             "model.layers.{bid}.residual_mlp.w1", # arctic
+            "transformer.h.{bid}.mlp.c_fc_0", # exaone
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -348,6 +354,7 @@
             "model.layers.{bid}.residual_mlp.w2", # arctic
             "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
             "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
+            "model.layers.h.{bid}.mlp.c_proj", # exaone
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
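
These mapping entries are what let the converter translate EXAONE checkpoint tensor names into canonical GGUF names. A rough sketch of that lookup through `gguf-py`, assuming `get_tensor_name_map` and `TensorNameMap.get_name` behave as they do for the other architectures; the block count of 32 is an arbitrary example value:

# Sketch: resolve an EXAONE (HF-side) tensor name to its GGUF name.
# Assumes gguf.get_tensor_name_map(arch, n_blocks) and get_name(..., try_suffixes=...)
# work the way convert_hf_to_gguf.py uses them; not verified against this exact commit.
import gguf

tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.EXAONE, 32)

hf_name = "transformer.h.0.attn.attention.q_proj.weight"
gguf_name = tensor_map.get_name(hf_name, try_suffixes=(".weight", ".bias"))
print(gguf_name)  # expected: "blk.0.attn_q.weight" if the suffix is re-appended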

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
         LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
         LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
+        LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
     };
 
     enum llama_rope_type {

src/llama-vocab.cpp

Lines changed: 1 addition & 0 deletions
@@ -388,6 +388,7 @@ struct llm_tokenizer_bpe {
             case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
             case LLAMA_VOCAB_PRE_TYPE_SMOLLM:
             case LLAMA_VOCAB_PRE_TYPE_CODESHELL:
+            case LLAMA_VOCAB_PRE_TYPE_EXAONE:
                 regex_exprs = {
                     "\\p{N}",
                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",

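With this change the EXAONE pre-tokenizer reuses the same GPT-2-style split rules as SMOLLM and CODESHELL: one expression that isolates individual digits and one that splits contractions, words, numbers, punctuation, and trailing whitespace. The sketch below is only a rough approximation of llama.cpp's splitter, written with Python's third-party `regex` module (which understands `\p{L}`/`\p{N}`), and assumes the two expressions are applied one after the other:

# Rough approximation of the two split rules above (not llama.cpp's actual code).
# pip install regex  -- the stdlib `re` module does not support \p{...} classes.
import regex

DIGIT_SPLIT = regex.compile(r"\p{N}")
GPT2_SPLIT = regex.compile(
    r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)"
)

def pre_tokenize(text: str) -> list[str]:
    # First rule: cut out every digit as its own piece.
    pieces, last = [], 0
    for m in DIGIT_SPLIT.finditer(text):
        if m.start() > last:
            pieces.append(text[last:m.start()])
        pieces.append(m.group())
        last = m.end()
    if last < len(text):
        pieces.append(text[last:])
    # Second rule: GPT-2-style splitting inside each remaining piece.
    out = []
    for piece in pieces:
        out.extend(GPT2_SPLIT.findall(piece))
    return out

print(pre_tokenize("Hello world 2024!"))
# ['Hello', ' world', ' ', '2', '0', '2', '4', '!']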