Commit d17a809

llama : support multiple classifier outputs and labels (#13940)

1 parent 1caae7f

File tree

6 files changed: +101 -24 lines

examples/embedding/embedding.cpp

Lines changed: 17 additions & 2 deletions
@@ -236,9 +236,24 @@ int main(int argc, char ** argv) {
             LOG("\n");
         }
     } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
+        const uint32_t n_cls_out = llama_model_n_cls_out(model);
+        std::vector<std::string> cls_out_labels;
+
+        for (uint32_t i = 0; i < n_cls_out; i++) {
+            const char * label = llama_model_cls_label(model, i);
+            const std::string label_i(label == nullptr ? "" : label);
+            cls_out_labels.emplace_back(label_i.empty() ? std::to_string(i) : label_i);
+        }
+
         for (int j = 0; j < n_embd_count; j++) {
-            // NOTE: if you change this log - update the tests in ci/run.sh
-            LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+            for (uint32_t i = 0; i < n_cls_out; i++) {
+                // NOTE: if you change this log - update the tests in ci/run.sh
+                if (n_cls_out == 1) {
+                    LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+                } else {
+                    LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
+                }
+            }
         }
     } else {
         // print the first part of the embeddings or for a single prompt, the full embedding
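The indexing above (one n_embd-sized slot per prompt, with the first n_cls_out floats holding the scores) could be wrapped in a small helper. A minimal sketch under that assumption; collect_rank_scores is a hypothetical name, not part of this commit:

#include <string>
#include <utility>
#include <vector>

#include "llama.h"

// Hypothetical helper: pair each classifier output with its label for one
// ranked prompt j, reading from the same pooled buffer layout as above -
// one n_embd-sized slot per prompt whose first n_cls_out floats are scores.
static std::vector<std::pair<std::string, float>> collect_rank_scores(
        const llama_model * model, const float * emb, int j, int n_embd) {
    const uint32_t n_cls_out = llama_model_n_cls_out(model);

    std::vector<std::pair<std::string, float>> scores;
    scores.reserve(n_cls_out);

    for (uint32_t i = 0; i < n_cls_out; i++) {
        const char * label = llama_model_cls_label(model, i);
        // fall back to the numeric index when no label is stored (as above)
        scores.emplace_back(
            (label && label[0]) ? std::string(label) : std::to_string(i),
            emb[j * n_embd + i]);
    }
    return scores;
}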

include/llama.h

Lines changed: 8 additions & 1 deletion
@@ -514,6 +514,13 @@ extern "C" {
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);

+    // Returns the number of classifier outputs (only valid for classifier models)
+    // Undefined behavior for non-classifier models
+    LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model);
+
+    // Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided
+    LLAMA_API const char * llama_model_cls_label(const struct llama_model * model, uint32_t i);
+
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);

     LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);

@@ -992,7 +999,7 @@ extern "C" {

     // Get the embeddings for a sequence id
     // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
-    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
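Taken together with llama_get_embeddings_seq, the new accessors let a caller read every classifier score for a sequence. A minimal sketch, assuming ctx and model are already initialized with LLAMA_POOLING_TYPE_RANK and a batch for sequence 0 has been decoded:

#include <cstdio>

#include "llama.h"

// Sketch only: assumes ctx/model are set up with LLAMA_POOLING_TYPE_RANK
// and a batch for seq_id 0 has already been decoded.
void print_rank_scores(llama_context * ctx, const llama_model * model) {
    const uint32_t n_cls_out = llama_model_n_cls_out(model);
    const float *  scores    = llama_get_embeddings_seq(ctx, 0); // float[n_cls_out]

    if (scores == NULL) {
        return; // e.g. pooling_type is LLAMA_POOLING_TYPE_NONE
    }
    for (uint32_t i = 0; i < n_cls_out; i++) {
        const char * label = llama_model_cls_label(model, i); // may be nullptr
        printf("output %u (%s): %8.3f\n", i, label ? label : "unlabeled", scores[i]);
    }
}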

src/llama-context.cpp

Lines changed: 4 additions & 3 deletions
@@ -839,16 +839,17 @@ int llama_context::encode(llama_batch & inp_batch) {
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // extract the rerank score - a single float per sequence
+                    // extract the rerank score - n_cls_out floats per sequence
                     auto & embd_seq_out = embd_seq;
+                    const uint32_t n_cls_out = hparams.n_cls_out;

                     for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
                         const llama_seq_id seq_id = ubatch.seq_id[s][0];
                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                             continue;
                         }
-                        embd_seq_out[seq_id].resize(1);
-                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        embd_seq_out[seq_id].resize(n_cls_out);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
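The byte offset passed to ggml_backend_tensor_get_async follows from a flat per-sequence layout of the pooled tensor. An illustrative sketch of the indexing, not part of the commit:

#include <cstddef>
#include <cstdint>

// Illustrative only: the pooled rank tensor holds one contiguous block of
// n_cls_out floats per sequence,
//
//   [ seq 0: out_0 .. out_{n_cls_out-1} | seq 1: out_0 .. | ... ]
//
// so classifier output k of sequence s lives at flat index s*n_cls_out + k.
static size_t rank_score_byte_offset(uint32_t seq_id, uint32_t k, uint32_t n_cls_out) {
    return (size_t) (seq_id * n_cls_out + k) * sizeof(float);
}
// With k = 0 and a copy length of n_cls_out*sizeof(float), this reproduces
// the (n_cls_out*seq_id)*sizeof(float) offset used above.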

src/llama-model-loader.cpp

Lines changed: 42 additions & 17 deletions
@@ -288,61 +288,84 @@ namespace GGUFMeta {

 template<typename T>
 bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
-    const int kid = gguf_find_key(meta.get(), key.c_str());
+    const gguf_context * ctx = meta.get();
+    const int kid = gguf_find_key(ctx, key.c_str());

-    if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
+    if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
         if (required) {
             throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
         }
         return false;
     }

     struct GGUFMeta::ArrayInfo arr_info =
-        GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
+        GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);

     switch (arr_info.gt) {
         case GGUF_TYPE_UINT32:
-        case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) ||
-                                            (std::is_same<T, uint32_t>::value)); break;
-        case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value));   break;
+        case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) ||
+                                            (std::is_same<T, uint32_t>::value));     break;
+        case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value));       break;
+        case GGUF_TYPE_STRING:  GGML_ASSERT((std::is_same<T, std::string>::value)); break;
         default:
-            throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
+            throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
     }

-    result.resize(arr_info.length);
-    result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
+    if constexpr (std::is_same<T, std::string>::value) {
+        const size_t n_items = gguf_get_arr_n(ctx, kid);
+        result.clear();
+
+        for (size_t i = 0; i < n_items; i++) {
+            const T value = gguf_get_arr_str(ctx, kid, i);
+            result.emplace_back(value);
+        }
+    } else {
+        result.resize(arr_info.length);
+        result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
+    }

     return true;
 }

 template<typename T, size_t N_MAX>
 bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
-    const int kid = gguf_find_key(meta.get(), key.c_str());
+    const gguf_context * ctx = meta.get();
+    const int kid = gguf_find_key(ctx, key.c_str());

-    if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
+    if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
         if (required) {
             throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
         }
         return false;
     }

     struct GGUFMeta::ArrayInfo arr_info =
-        GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
+        GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);

     switch (arr_info.gt) {
         case GGUF_TYPE_UINT32:
-        case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) ||
-                                            (std::is_same<T, uint32_t>::value)); break;
-        case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value));   break;
+        case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) ||
+                                            (std::is_same<T, uint32_t>::value));     break;
+        case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value));       break;
+        case GGUF_TYPE_STRING:  GGML_ASSERT((std::is_same<T, std::string>::value)); break;
         default:
-            throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
+            throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
     }

     if (arr_info.length > N_MAX) {
         throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
     }

-    std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
+    if constexpr (std::is_same<T, std::string>::value) {
+        const size_t n_items = gguf_get_arr_n(ctx, kid);
+
+        for (size_t i = 0; i < n_items; i++) {
+            const T value = gguf_get_arr_str(ctx, kid, i);
+            result[i] = value;
+        }
+    } else {
+        std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
+    }

     return true;
 }

@@ -352,6 +375,8 @@ namespace GGUFMeta {
     return get_arr(llm_kv(kid), result, required);
 }

+template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required);
+
 template<typename T>
 bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
     auto it = kv_overrides.find(key);
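String arrays cannot be bulk-copied out of the key-value blob the way numeric arrays are, hence the element-by-element loop over gguf_get_arr_str. A standalone sketch of the same pattern against ggml's gguf API (per ggml's gguf.h); the file name and key are placeholders:

#include <cstdio>

#include "gguf.h"

// Sketch: read an arbitrary string-array key from a GGUF file, element by
// element, mirroring the string branch of get_arr() above.
int main() {
    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
    struct gguf_context * ctx = gguf_init_from_file("model.gguf", params);
    if (!ctx) {
        return 1;
    }

    const int64_t kid = gguf_find_key(ctx, "some.string.array.key");
    if (kid >= 0 && gguf_get_kv_type(ctx, kid) == GGUF_TYPE_ARRAY &&
        gguf_get_arr_type(ctx, kid) == GGUF_TYPE_STRING) {
        const size_t n = gguf_get_arr_n(ctx, kid);
        for (size_t i = 0; i < n; i++) {
            printf("[%zu] %s\n", i, gguf_get_arr_str(ctx, kid, i));
        }
    }

    gguf_free(ctx);
    return 0;
}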

src/llama-model.cpp

Lines changed: 27 additions & 1 deletion
@@ -543,6 +543,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     uint32_t n_vocab = 0;
     ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);

+    // for classifier models
+    ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
+    if (!classifier_labels.empty()) {
+        hparams.n_cls_out = classifier_labels.size();
+    }
+
     // arch-specific KVs
     switch (arch) {
         case LLM_ARCH_LLAMA:

@@ -686,7 +692,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
             ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
             ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
-            ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);

             switch (hparams.n_layer) {
                 case 3:

@@ -4362,6 +4367,15 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n", __func__, hparams.ssm_d_state);
         LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n", __func__, hparams.ssm_dt_rank);
         LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n", __func__, hparams.ssm_dt_b_c_rms);
+
+        if (!classifier_labels.empty()) {
+            LLAMA_LOG_INFO("%s: n_cls_out        = %u\n", __func__, hparams.n_cls_out);
+
+            size_t i = 0;
+            for (auto label : classifier_labels) {
+                LLAMA_LOG_INFO("%s: cls_label[%2zu]    = %s\n", __func__, i++, label.c_str());
+            }
+        }
     }

     LLAMA_LOG_INFO("%s: model type       = %s\n", __func__, type_name().c_str());

@@ -13602,6 +13616,18 @@ int32_t llama_model_n_swa(const llama_model * model) {
     return model->hparams.n_swa;
 }

+uint32_t llama_model_n_cls_out(const struct llama_model * model) {
+    return model->hparams.n_cls_out;
+}
+
+const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) {
+    if (i < model->classifier_labels.size()) {
+        return model->classifier_labels[i].c_str();
+    }
+
+    return nullptr;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
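As implemented, llama_model_cls_label returns nullptr for any index past the stored label list (and for models without labels), so callers can probe it defensively. A minimal sketch, hypothetical and not from the commit:

#include <cstdio>

#include "llama.h"

// Dump all classifier outputs of a loaded model; indices without a stored
// label print as "(unlabeled)" since llama_model_cls_label returns nullptr.
void dump_cls_labels(const llama_model * model) {
    const uint32_t n = llama_model_n_cls_out(model);
    for (uint32_t i = 0; i < n; i++) {
        const char * label = llama_model_cls_label(model, i);
        printf("cls output %u: %s\n", i, label ? label : "(unlabeled)");
    }
}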

src/llama-model.h

Lines changed: 3 additions & 0 deletions
@@ -329,6 +329,9 @@ struct llama_model {
     llama_hparams hparams = {};
     llama_vocab   vocab;

+    // for classifier models
+    std::vector<std::string> classifier_labels;
+
     struct ggml_tensor * tok_embd   = nullptr;
     struct ggml_tensor * type_embd  = nullptr;
     struct ggml_tensor * pos_embd   = nullptr;
