Skip to content

Commit cebe433

Browse files
iamlemec authored and Neo Zhang committed
llama : streamline embeddings from "non-embedding" models (ggml-org#8087)
1 parent 214dd3d commit cebe433

File tree

4 files changed

+36
-10
lines changed

4 files changed

+36
-10
lines changed

common/common.cpp

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -472,6 +472,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
472472
else { invalid_param = true; }
473473
return true;
474474
}
475+
if (arg == "--attention") {
476+
CHECK_ARG
477+
std::string value(argv[i]);
478+
/**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
479+
else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; }
480+
else { invalid_param = true; }
481+
return true;
482+
}
475483
if (arg == "--defrag-thold" || arg == "-dt") {
476484
CHECK_ARG
477485
params.defrag_thold = std::stof(argv[i]);
@@ -1468,8 +1476,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
14681476
"For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead" });
14691477

14701478
options.push_back({ "embedding" });
1471-
options.push_back({ "embedding", " --pooling {none,mean,cls}",
1479+
options.push_back({ "embedding", " --pooling {none,mean,cls,last}",
14721480
"pooling type for embeddings, use model default if unspecified" });
1481+
options.push_back({ "embedding", " --attention {causal,non-causal}",
1482+
"attention type for embeddings, use model default if unspecified" });
14731483

14741484
options.push_back({ "context hacking" });
14751485
options.push_back({ "*", " --rope-scaling {none,linear,yarn}",
@@ -2175,6 +2185,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
21752185
cparams.yarn_beta_slow = params.yarn_beta_slow;
21762186
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
21772187
cparams.pooling_type = params.pooling_type;
2188+
cparams.attention_type = params.attention_type;
21782189
cparams.defrag_thold = params.defrag_thold;
21792190
cparams.cb_eval = params.cb_eval;
21802191
cparams.cb_eval_user_data = params.cb_eval_user_data;

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,7 @@ struct gpt_params {
9999
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
100100
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
101101
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
102+
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
102103

103104
// // sampling parameters
104105
struct llama_sampling_params sparams;

include/llama.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -180,6 +180,12 @@ extern "C" {
180180
LLAMA_POOLING_TYPE_LAST = 3,
181181
};
182182

183+
enum llama_attention_type {
184+
LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
185+
LLAMA_ATTENTION_TYPE_CAUSAL = 0,
186+
LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1,
187+
};
188+
183189
enum llama_split_mode {
184190
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
185191
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -297,6 +303,7 @@ extern "C" {
297303

298304
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
299305
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
306+
enum llama_attention_type attention_type; // attention type to use for embeddings
300307

301308
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
302309
float rope_freq_base; // RoPE base frequency, 0 = from model

src/llama.cpp

Lines changed: 16 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -13841,7 +13841,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
1384113841
}
1384213842
}
1384313843

13844-
if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
13844+
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
1384513845
const int64_t n_tokens = batch.n_tokens;
1384613846

1384713847
GGML_ASSERT(lctx.inp_mean);
@@ -13873,7 +13873,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
1387313873
}
1387413874
}
1387513875

13876-
if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
13876+
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
1387713877
const int64_t n_tokens = batch.n_tokens;
1387813878

1387913879
GGML_ASSERT(lctx.inp_cls);
@@ -13894,7 +13894,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
1389413894
}
1389513895
}
1389613896

13897-
if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
13897+
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
1389813898
const int64_t n_tokens = batch.n_tokens;
1389913899

1390013900
GGML_ASSERT(lctx.inp_cls);
@@ -14182,14 +14182,15 @@ static int llama_decode_internal(
1418214182
std::vector<llama_seq_id *> seq_id_arr;
1418314183
std::vector<std::vector<llama_seq_id>> seq_id;
1418414184

14185+
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
14186+
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
14187+
1418514188
// count outputs
14186-
if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
14187-
n_outputs = n_tokens_all;
14188-
} else if (batch_all.logits) {
14189+
if (batch_all.logits && !embd_pooled) {
1418914190
for (uint32_t i = 0; i < n_tokens_all; ++i) {
1419014191
n_outputs += batch_all.logits[i] != 0;
1419114192
}
14192-
} else if (lctx.logits_all) {
14193+
} else if (lctx.logits_all || embd_pooled) {
1419314194
n_outputs = n_tokens_all;
1419414195
} else {
1419514196
// keep last output only
@@ -14235,7 +14236,7 @@ static int llama_decode_internal(
1423514236
{
1423614237
int32_t n_outputs_new = 0;
1423714238

14238-
if (u_batch.logits) {
14239+
if (u_batch.logits && !embd_pooled) {
1423914240
for (uint32_t i = 0; i < n_tokens; i++) {
1424014241
n_outputs_new += u_batch.logits[i] != 0;
1424114242
}
@@ -18534,6 +18535,7 @@ struct llama_context_params llama_context_default_params() {
1853418535
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
1853518536
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
1853618537
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
18538+
/*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
1853718539
/*.rope_freq_base =*/ 0.0f,
1853818540
/*.rope_freq_scale =*/ 0.0f,
1853918541
/*.yarn_ext_factor =*/ -1.0f,
@@ -18786,7 +18788,6 @@ struct llama_context * llama_new_context_with_model(
1878618788
}
1878718789

1878818790
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
18789-
cparams.causal_attn = hparams.causal_attn;
1879018791

1879118792
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
1879218793
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -18796,6 +18797,12 @@ struct llama_context * llama_new_context_with_model(
1879618797
}
1879718798
}
1879818799

18800+
if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
18801+
cparams.causal_attn = hparams.causal_attn;
18802+
} else {
18803+
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
18804+
}
18805+
1879918806
if (params.seed == LLAMA_DEFAULT_SEED) {
1880018807
params.seed = time(NULL);
1880118808
}

0 commit comments

Comments (0)