
Commit 1595b69

add exaone model support
1 parent 3071c0a commit 1595b69

File tree

6 files changed, +301 -7 lines changed


convert_hf_to_gguf.py

Lines changed: 72 additions & 0 deletions
@@ -590,6 +590,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "855059429035d75a914d1eda9f10a876752e281a054a7a3d421ef0533e5b6249":
             # ref: https://huggingface.co/HuggingFaceTB/SmolLM-135M
             res = "smollm"
+        if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae":
+            # ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct
+            res = "exaone"

         if res is None:
             logger.warning("\n")
@@ -3595,6 +3598,75 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         name = name.removeprefix("transformer.")
         return [(self.map_tensor_name(name), data_torch)]

+@Model.register("ExaoneForCausalLM")
+class ExaoneModel(Model):
+    model_arch = gguf.MODEL_ARCH.EXAONE
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+
+        assert(hparams["activation_function"] == "silu")
+
+        max_position_embeddings = hparams["max_position_embeddings"]
+        embed_dim = hparams["hidden_size"]
+        num_heads = hparams["num_attention_heads"]
+        num_kv_heads = hparams.get("num_key_value_heads", num_heads)
+        layer_norm_eps = hparams["layer_norm_epsilon"]
+        intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
+        num_layers = hparams["num_layers"]
+        # ignore for now as EXAONE-3.0-7.8B-Instruct attention_dropout is 0.0
+        # attention_dropout_rate = hparams["attention_dropout"]
+        # ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
+        # embed_dropout_rate = hparams["embed_dropout"]
+        self.gguf_writer.add_embedding_length(embed_dim)
+        self.gguf_writer.add_head_count(num_heads)
+        self.gguf_writer.add_head_count_kv(num_kv_heads)
+        self.gguf_writer.add_context_length(max_position_embeddings)
+        self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_block_count(num_layers)
+
+        if (rope_theta := self.hparams.get("rope_theta")) is not None:
+            self.gguf_writer.add_rope_freq_base(rope_theta)
+        rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
+        rotary_factor = rotary_factor if rotary_factor is not None else 1.0
+        self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
+        if hparams.get("rope_scaling") is not None and "factor" in hparams["rope_scaling"]:
+            if hparams["rope_scaling"].get("type") == "linear":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
+
+    def prepare_tensors(self):
+        if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
+            if rope_scaling.get("rope_type", '').lower() == "llama3":
+                base = self.hparams.get("rope_theta", 10000.0)
+                dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+                freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+                factor = rope_scaling.get("factor", 8.0)
+                low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+                high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+                old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
+
+                low_freq_wavelen = old_context_len / low_freq_factor
+                high_freq_wavelen = old_context_len / high_freq_factor
+                assert low_freq_wavelen != high_freq_wavelen
+
+                rope_factors = []
+                for freq in freqs:
+                    wavelen = 2 * math.pi / freq
+                    if wavelen < high_freq_wavelen:
+                        rope_factors.append(1)
+                    elif wavelen > low_freq_wavelen:
+                        rope_factors.append(factor)
+                    else:
+                        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+                        rope_factors.append(1 / ((1 - smooth) / factor + smooth))
+
+                self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
+
+        super().prepare_tensors()
+
 ###### CONVERSION LOGIC ######

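For reference, the llama3-style RoPE factor computation in ExaoneModel.prepare_tensors above can be exercised on its own. The sketch below mirrors that loop with plain NumPy; the head_dim and scaling parameters are illustrative defaults, not values read from the EXAONE config.

# Standalone sketch of the llama3-style rope factor computation used above.
# The hyperparameter values are placeholders for illustration only.
import math
import numpy as np

def llama3_rope_factors(head_dim, base=500000.0, factor=8.0,
                        low_freq_factor=1.0, high_freq_factor=4.0,
                        old_context_len=8192):
    # inverse frequencies, one per pair of rotary dimensions
    freqs = 1.0 / (base ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim))
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    rope_factors = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            rope_factors.append(1.0)        # high-frequency dims stay unscaled
        elif wavelen > low_freq_wavelen:
            rope_factors.append(factor)     # low-frequency dims get the full scaling factor
        else:
            # smooth interpolation between the two regimes
            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            rope_factors.append(1.0 / ((1.0 - smooth) / factor + smooth))
    return np.array(rope_factors, dtype=np.float32)

print(llama3_rope_factors(head_dim=128))  # one factor per pair of rotary dims
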
convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", },
     {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", },
     {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", },
+    {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
 ]

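The entry above only registers the repo; the chkhsh value checked in convert_hf_to_gguf.py comes from hashing the tokenizer's output on a fixed probe string. A rough sketch of that flow, assuming the transformers AutoTokenizer API and a placeholder probe string (the real chktxt lives in the scripts, so the hash printed here will not match the one in this commit):

from hashlib import sha256
from transformers import AutoTokenizer

# Placeholder probe text -- NOT the real chktxt used by the scripts.
chktxt = "placeholder probe text"

# trust_remote_code may be required for this repo depending on the transformers version
tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", trust_remote_code=True)
chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
print(chkhsh)
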
gguf-py/gguf/constants.py

Lines changed: 18 additions & 0 deletions
@@ -218,6 +218,7 @@ class MODEL_ARCH(IntEnum):
     BITNET = auto()
     T5     = auto()
     JAIS   = auto()
+    EXAONE = auto()


 class MODEL_TENSOR(IntEnum):
@@ -345,6 +346,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.BITNET: "bitnet",
     MODEL_ARCH.T5:     "t5",
     MODEL_ARCH.JAIS:   "jais",
+    MODEL_ARCH.EXAONE: "exaone",
 }

 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -1048,6 +1050,22 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.EXAONE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }

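A quick way to list the GGUF-side names the new MODEL_ARCH.EXAONE entry covers, assuming the gguf-py package from this repository is importable and that TENSOR_NAMES formats per-block tensors with a {bid} placeholder as in the file above:

import gguf

# One GGUF-side name per tensor the EXAONE arch declares; block 0 is used for
# the per-block tensors (names without a {bid} placeholder print unchanged).
for tensor in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.EXAONE]:
    print(tensor.name, "->", gguf.TENSOR_NAMES[tensor].format(bid=0))
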
gguf-py/gguf/tensor_mapping.py

Lines changed: 13 additions & 6 deletions
@@ -10,7 +10,7 @@ class TensorNameMap:
         # Token embeddings
         MODEL_TENSOR.TOKEN_EMBD: (
             "gpt_neox.embed_in", # gptneox
-            "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais
+            "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
             "transformer.word_embeddings", # falcon
             "word_embeddings", # bloom
             "model.embed_tokens", # llama-hf
@@ -52,7 +52,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out", # gptneox
-            "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais
+            "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais exaone
             "output", # llama-pth bloom internlm2
             "word_embeddings_for_head", # persimmon
             "lm_head.linear", # phi2
@@ -62,7 +62,7 @@ class TensorNameMap:
         # Output norm
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm", # gptneox
-            "transformer.ln_f", # gpt2 gpt-j falcon jais
+            "transformer.ln_f", # gpt2 gpt-j falcon jais exaone
             "model.norm", # llama-hf baichuan internlm2
             "norm", # llama-pth
             "transformer.norm_f", # mpt dbrx
@@ -88,7 +88,7 @@ class TensorNameMap:
         # Attention norm
         MODEL_TENSOR.ATTN_NORM: (
             "gpt_neox.layers.{bid}.input_layernorm", # gptneox
-            "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais
+            "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone
             "transformer.blocks.{bid}.norm_1", # mpt
             "transformer.h.{bid}.input_layernorm", # falcon7b
             "h.{bid}.input_layernorm", # bloom
@@ -142,6 +142,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.self_attn.q_proj", # plamo
             "model.layers.{bid}.attention.wq", # internlm2
             "transformer.decoder_layer.{bid}.multi_head_attention.query", # Grok
+            "transformer.h.{bid}.attn.attention.q_proj", # exaone
         ),

         # Attention key
@@ -154,6 +155,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.self_attn.k_proj", # plamo
             "model.layers.{bid}.attention.wk", # internlm2
             "transformer.decoder_layer.{bid}.multi_head_attention.key", # Grok
+            "transformer.h.{bid}.attn.attention.k_proj", # exaone
         ),

         # Attention value
@@ -165,7 +167,8 @@ class TensorNameMap:
             "transformer.h.{bid}.attn.v", # refact
             "model.layers.layers.{bid}.self_attn.v_proj", # plamo
             "model.layers.{bid}.attention.wv", # internlm2
-            "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
+            "transformer.decoder_layer.{bid}.multi_head_attention.value", # Grok
+            "transformer.h.{bid}.attn.attention.v_proj", # exaone
         ),

         # Attention output
@@ -190,6 +193,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
             "encoder.layers.{bid}.self_attention.dense", # chatglm
             "transformer.layers.{bid}.attn.out_proj", # openelm
+            "transformer.h.{bid}.attn.attention.out_proj", # exaone
         ),

         # Attention output norm
@@ -215,7 +219,7 @@ class TensorNameMap:
         # Feed-forward norm
         MODEL_TENSOR.FFN_NORM: (
             "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
-            "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais
+            "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
             "h.{bid}.post_attention_layernorm", # bloom
             "transformer.blocks.{bid}.norm_2", # mpt
             "model.layers.{bid}.post_attention_layernorm", # llama-hf
@@ -277,6 +281,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
             "model.layers.{bid}.residual_mlp.w3", # arctic
             "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
+            "transformer.h.{bid}.mlp.c_fc_1", # exaone
         ),

         MODEL_TENSOR.FFN_UP_EXP: (
@@ -308,6 +313,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
             "transformer.h.{bid}.mlp.linear_1", # refact
             "model.layers.{bid}.residual_mlp.w1", # arctic
+            "transformer.h.{bid}.mlp.c_fc_0", # exaone
         ),

         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -347,6 +353,7 @@ class TensorNameMap:
             "model.layers.{bid}.residual_mlp.w2", # arctic
             "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
             "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
+            "model.layers.h.{bid}.mlp.c_proj", # exaone
         ),

         MODEL_TENSOR.FFN_DOWN_EXP: (

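To sanity-check the new mappings, an HF-side exaone tensor name can be resolved through TensorNameMap; a minimal sketch, assuming gguf.get_tensor_name_map and TensorNameMap.get_name behave as in the gguf-py version this commit targets (the block count is illustrative):

import gguf

n_blocks = 32  # illustrative; the converter reads the real block count from the model config
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.EXAONE, n_blocks)

hf_name = "transformer.h.0.attn.attention.q_proj.weight"
# should resolve to something like "blk.0.attn_q.weight"
print(tmap.get_name(hf_name, try_suffixes=(".weight", ".bias")))
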
include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_TEKKEN    = 20,
         LLAMA_VOCAB_PRE_TYPE_SMOLLM    = 21,
         LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
+        LLAMA_VOCAB_PRE_TYPE_EXAONE    = 23,
     };

     // note: these values should be synchronized with ggml_rope
