
Commit 291f2b6

huydt84 and huydt-bti authored

llama : add support for DistilBert (#13907)

* add distilbert
* small fixes
* add note for LLM_ARCH_DISTIL_BERT
* Use MODEL_ARCH.BERT for DistilBert

Co-authored-by: dinhhuy <[email protected]>

1 parent 2c90da4 · commit 291f2b6

File tree (3 files changed, +37 −6 lines):

* convert_hf_to_gguf.py
* gguf-py/gguf/tensor_mapping.py
* src/llama-model.cpp


convert_hf_to_gguf.py

Lines changed: 23 additions & 3 deletions
@@ -523,15 +523,15 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_context_length(n_ctx)
         logger.info(f"gguf: context length = {n_ctx}")

-        if (n_embd := self.find_hparam(["hidden_size", "n_embd"], optional=True)) is not None:
+        if (n_embd := self.find_hparam(["hidden_size", "n_embd", "dim"], optional=True)) is not None:
             self.gguf_writer.add_embedding_length(n_embd)
             logger.info(f"gguf: embedding length = {n_embd}")

-        if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
+        if (n_ff := self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"], optional=True)) is not None:
             self.gguf_writer.add_feed_forward_length(n_ff)
             logger.info(f"gguf: feed forward length = {n_ff}")

-        if (n_head := self.find_hparam(["num_attention_heads", "n_head"], optional=True)) is not None:
+        if (n_head := self.find_hparam(["num_attention_heads", "n_head", "n_heads"], optional=True)) is not None:
             self.gguf_writer.add_head_count(n_head)
             logger.info(f"gguf: head count = {n_head}")
@@ -3907,6 +3907,26 @@ def _xlmroberta_set_vocab(self) -> None:
         self.gguf_writer.add_add_eos_token(True)


+@ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification")
+class DistilBertModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_layer_norm_eps(1e-12)
+        logger.info("gguf: layer norm epsilon = 1e-12")
+        super().set_gguf_parameters()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("distilbert."):
+            name = name[11:]
+
+        # These layers act as MLM head, so we don't need them
+        if name.startswith("vocab_"):
+            return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("RobertaModel", "RobertaForSequenceClassification")
 class RobertaModel(BertModel):
     model_arch = gguf.MODEL_ARCH.BERT
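
The new subclass reuses BertModel and mainly has to normalize Hugging Face tensor names: DistilBERT checkpoints prefix every weight with "distilbert." and ship MLM-head tensors (vocab_transform, vocab_layer_norm, vocab_projector) that an embedding/classification GGUF does not need. A small standalone sketch of that renaming and filtering, using example tensor names rather than a real checkpoint:

def normalize_distilbert_name(name: str) -> str | None:
    # Strip the checkpoint-wide prefix, as DistilBertModel.modify_tensors does.
    if name.startswith("distilbert."):
        name = name[len("distilbert."):]
    # Drop MLM-head tensors; they are not part of the encoder graph.
    if name.startswith("vocab_"):
        return None
    return name

# Example names as they would appear in a DistilBertForMaskedLM checkpoint.
examples = [
    "distilbert.transformer.layer.0.attention.q_lin.weight",
    "vocab_projector.weight",
]
for n in examples:
    print(n, "->", normalize_distilbert_name(n))
# distilbert.transformer.layer.0.attention.q_lin.weight -> transformer.layer.0.attention.q_lin.weight
# vocab_projector.weight -> None (skipped)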

gguf-py/gguf/tensor_mapping.py

Lines changed: 9 additions & 0 deletions
@@ -169,6 +169,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
             "layers.{bid}.attention.wq", # llama-pth
             "encoder.layer.{bid}.attention.self.query", # bert
+            "transformer.layer.{bid}.attention.q_lin", # distillbert
             "transformer.h.{bid}.attn.q_proj", # gpt-j
             "model.layers.layers.{bid}.self_attn.q_proj", # plamo
             "model.layers.{bid}.attention.wq", # internlm2

@@ -183,6 +184,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
             "layers.{bid}.attention.wk", # llama-pth
             "encoder.layer.{bid}.attention.self.key", # bert
+            "transformer.layer.{bid}.attention.k_lin", # distillbert
             "transformer.h.{bid}.attn.k_proj", # gpt-j
             "transformer.h.{bid}.attn.k", # refact
             "model.layers.layers.{bid}.self_attn.k_proj", # plamo

@@ -197,6 +199,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe
             "layers.{bid}.attention.wv", # llama-pth
             "encoder.layer.{bid}.attention.self.value", # bert
+            "transformer.layer.{bid}.attention.v_lin", # distillbert
             "transformer.h.{bid}.attn.v_proj", # gpt-j
             "transformer.h.{bid}.attn.v", # refact
             "model.layers.layers.{bid}.self_attn.v_proj", # plamo

@@ -217,6 +220,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.linear_attn", # deci
             "layers.{bid}.attention.wo", # llama-pth
             "encoder.layer.{bid}.attention.output.dense", # bert
+            "transformer.layer.{bid}.attention.out_lin", # distillbert
             "transformer.h.{bid}.attn.out_proj", # gpt-j
             "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
             "model.layers.{bid}.self_attn.dense", # persimmon

@@ -237,6 +241,7 @@ class TensorNameMap:
         # Attention output norm
         MODEL_TENSOR.ATTN_OUT_NORM: (
             "encoder.layer.{bid}.attention.output.LayerNorm", # bert
+            "transformer.layer.{bid}.sa_layer_norm", # distillbert
             "encoder.layers.{bid}.norm1", # nomic-bert
             "transformer.decoder_layer.{bid}.rms_norm_1", # Grok
             "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx

@@ -313,6 +318,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2
             "layers.{bid}.feed_forward.w3", # llama-pth
             "encoder.layer.{bid}.intermediate.dense", # bert
+            "transformer.layer.{bid}.ffn.lin1", # distillbert
             "transformer.h.{bid}.mlp.fc_in", # gpt-j
             "transformer.h.{bid}.mlp.linear_3", # refact
             "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon

@@ -396,6 +402,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2
             "layers.{bid}.feed_forward.w2", # llama-pth
             "encoder.layer.{bid}.output.dense", # bert
+            "transformer.layer.{bid}.ffn.lin2", # distillbert
             "transformer.h.{bid}.mlp.fc_out", # gpt-j
             "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
             "model.layers.{bid}.mlp.dense_4h_to_h", # persimmon

@@ -457,6 +464,7 @@ class TensorNameMap:

         MODEL_TENSOR.LAYER_OUT_NORM: (
             "encoder.layer.{bid}.output.LayerNorm", # bert
+            "transformer.layer.{bid}.output_layer_norm", # distillbert
             "encoder.layers.{bid}.norm2", # nomic-bert
             "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
             "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2

@@ -827,6 +835,7 @@ class TensorNameMap:
         MODEL_TENSOR.CLS: (
             "classifier", # jina
             "classifier.dense", # roberta
+            "pre_classifier", # distillbert
         ),

         MODEL_TENSOR.CLS_OUT: (
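
Together with the existing bert entries, these additions let TensorNameMap resolve DistilBERT's transformer.layer.{bid}.* names to the same per-block GGUF tensors the BERT graph already expects. A rough sketch of that lookup idea; the mapping dict and the GGUF-side target names here are illustrative assumptions, not the real TensorNameMap tables:

import re

# Illustrative subset: DistilBERT source patterns -> assumed GGUF-style targets.
DISTILBERT_MAP = {
    r"transformer\.layer\.(\d+)\.attention\.q_lin": "blk.{bid}.attn_q",
    r"transformer\.layer\.(\d+)\.attention\.k_lin": "blk.{bid}.attn_k",
    r"transformer\.layer\.(\d+)\.attention\.v_lin": "blk.{bid}.attn_v",
    r"transformer\.layer\.(\d+)\.attention\.out_lin": "blk.{bid}.attn_output",
    r"transformer\.layer\.(\d+)\.sa_layer_norm": "blk.{bid}.attn_output_norm",
    r"transformer\.layer\.(\d+)\.ffn\.lin1": "blk.{bid}.ffn_up",
    r"transformer\.layer\.(\d+)\.ffn\.lin2": "blk.{bid}.ffn_down",
    r"transformer\.layer\.(\d+)\.output_layer_norm": "blk.{bid}.layer_output_norm",
}

def map_name(hf_name: str) -> str | None:
    # Split off the ".weight"/".bias" suffix, match the base name, re-attach the suffix.
    base, _, suffix = hf_name.rpartition(".")
    for pattern, target in DISTILBERT_MAP.items():
        m = re.fullmatch(pattern, base)
        if m:
            return target.format(bid=m.group(1)) + "." + suffix
    return None

print(map_name("transformer.layer.3.attention.q_lin.weight"))  # blk.3.attn_q.weight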

src/llama-model.cpp

Lines changed: 5 additions & 3 deletions
@@ -2114,7 +2114,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
            case LLM_ARCH_NOMIC_BERT_MOE:
                {
                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);
+                    type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED);

                    if (arch == LLM_ARCH_BERT) {
                        pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);

@@ -5885,8 +5885,10 @@ struct llm_build_bert : public llm_graph_context {
        inpL = build_inp_embd(model.tok_embd);

        // token types are hardcoded to zero ("Sentence A")
-        ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
-        inpL = ggml_add(ctx0, inpL, type_row0);
+        if (model.type_embd) {
+            ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
+            inpL = ggml_add(ctx0, inpL, type_row0);
+        }
        if (model.arch == LLM_ARCH_BERT) {
            inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
        }
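
Making the token-type embedding optional matters because DistilBERT has no token_type_embeddings, so a converted DistilBERT GGUF simply does not contain that tensor and the BERT graph must skip the corresponding add. One way to sanity-check a converted file from Python with the gguf package; the model path is a placeholder and the exact tensor name is my assumption about how the converter writes it:

from gguf import GGUFReader  # pip install gguf

# Placeholder path to a DistilBERT model converted with convert_hf_to_gguf.py.
reader = GGUFReader("distilbert-base-uncased-f16.gguf")

names = {t.name for t in reader.tensors}
# BERT-family GGUFs normally carry a "token_types.weight" tensor; a DistilBERT
# conversion should not, which is why the loader treats it as TENSOR_NOT_REQUIRED.
print("token_types.weight" in names)  # expected: False for DistilBERT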
