Skip to content

Commit e0e3aa2

Browse files
huydt84 and huydt-bti authored
llama : add support for BertForSequenceClassification reranker (#13858)
* convert: add support for BertForSequenceClassification * add support for reranking using BertForSequenceClassification * merge checks of eos and sep * fix lint --------- Co-authored-by: dinhhuy <[email protected]>
1 parent aa6dff0 commit e0e3aa2

File tree

4 files changed

+43
-22
lines changed

4 files changed

+43
-22
lines changed

common/common.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -903,13 +903,16 @@ struct common_init_result common_init_from_params(common_params & params) {
903903
ok = false;
904904
}
905905

906-
if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
907-
LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
908-
ok = false;
909-
}
906+
bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
907+
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
910908

911-
if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) {
912-
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
909+
if (!has_eos && !has_sep) {
910+
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
911+
ok = false;
912+
} else if (!has_eos) {
913+
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
914+
} else if (!has_sep) {
915+
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
913916
ok = false;
914917
}
915918

convert_hf_to_gguf.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3682,7 +3682,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
36823682
return [(self.map_tensor_name(name), data_torch)]
36833683

36843684

3685-
@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel")
3685+
@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel", "BertForSequenceClassification")
36863686
class BertModel(TextModel):
36873687
model_arch = gguf.MODEL_ARCH.BERT
36883688

@@ -3745,6 +3745,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
37453745
if name.startswith("cls.seq_relationship"):
37463746
return []
37473747

3748+
# For BertForSequenceClassification (direct projection layer)
3749+
if name == "classifier.weight":
3750+
name = "classifier.out_proj.weight"
3751+
3752+
if name == "classifier.bias":
3753+
name = "classifier.out_proj.bias"
3754+
37483755
return [(self.map_tensor_name(name), data_torch)]
37493756

37503757
def _xlmroberta_tokenizer_init(self) -> None:

src/llama-graph.cpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,20 +1562,25 @@ void llm_graph_context::build_pooling(
15621562
ggml_tensor * inp_cls = build_inp_cls();
15631563
inp = ggml_get_rows(ctx0, inp, inp_cls);
15641564

1565-
// classification head
1566-
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1567-
GGML_ASSERT(cls != nullptr);
1568-
GGML_ASSERT(cls_b != nullptr);
1569-
1570-
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
1571-
cur = ggml_tanh(ctx0, cur);
1572-
1573-
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
1574-
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
1575-
if (cls_out) {
1565+
if (cls != nullptr && cls_b != nullptr) {
1566+
// classification head
1567+
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
1568+
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
1569+
cur = ggml_tanh(ctx0, cur);
1570+
1571+
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
1572+
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
1573+
if (cls_out) {
1574+
GGML_ASSERT(cls_out_b != nullptr);
1575+
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
1576+
}
1577+
} else if (cls_out) {
1578+
// Single layer classification head (direct projection)
1579+
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
15761580
GGML_ASSERT(cls_out_b != nullptr);
1577-
1578-
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
1581+
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
1582+
} else {
1583+
GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
15791584
}
15801585
} break;
15811586
default:

tools/server/utils.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,13 +264,19 @@ static size_t validate_utf8(const std::string& text) {
264264
static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
265265
llama_tokens result;
266266

267+
// Get EOS token - use SEP token as fallback if EOS is not available
268+
llama_token eos_token = llama_vocab_eos(vocab);
269+
if (eos_token == LLAMA_TOKEN_NULL) {
270+
eos_token = llama_vocab_sep(vocab);
271+
}
272+
267273
result.reserve(doc.size() + query.size() + 4);
268274
result.push_back(llama_vocab_bos(vocab));
269275
result.insert(result.end(), query.begin(), query.end());
270-
result.push_back(llama_vocab_eos(vocab));
276+
result.push_back(eos_token);
271277
result.push_back(llama_vocab_sep(vocab));
272278
result.insert(result.end(), doc.begin(), doc.end());
273-
result.push_back(llama_vocab_eos(vocab));
279+
result.push_back(eos_token);
274280

275281
return result;
276282
}

0 commit comments

Comments (0)