Skip to content

Commit b4e4b8a

Browse files
authored
llama : add llama_get_pooling_type function (#6862)
* add llama_get_pooling_type function * fix argument name, move with ctx funcs
1 parent 3fe847b commit b4e4b8a

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

common/common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ struct gpt_params {
8686

8787
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
8888

89-
llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
90-
llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
89+
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
90+
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
9191

9292
// // sampling parameters
9393
struct llama_sampling_params sparams;

llama.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15599,6 +15599,10 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
1559915599
return LLAMA_ROPE_TYPE_NONE;
1560015600
}
1560115601

15602+
enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
15603+
return ctx->cparams.pooling_type;
15604+
}
15605+
1560215606
int32_t llama_n_vocab(const struct llama_model * model) {
1560315607
return model->hparams.n_vocab;
1560415608
}

llama.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,10 @@ extern "C" {
390390
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
391391
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
392392

393-
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
394-
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
393+
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
394+
395+
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
396+
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
395397

396398
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
397399
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);

0 commit comments

Comments
 (0)