
Commit 8093253

take out attention_type; add in llama_set_embeddings
1 parent d4e6972 commit 8093253

5 files changed: 19 additions & 39 deletions

common/common.cpp

Lines changed: 0 additions & 12 deletions
@@ -546,17 +546,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         else { invalid_param = true; }
         return true;
     }
-    if (arg == "--attention") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
-        std::string value(argv[i]);
-        /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
-        else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL; }
-        else { invalid_param = true; }
-        return true;
-    }
     if (arg == "--defrag-thold" || arg == "-dt") {
         if (++i >= argc) {
             invalid_param = true;
@@ -2460,7 +2449,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_slow = params.yarn_beta_slow;
     cparams.yarn_orig_ctx = params.yarn_orig_ctx;
     cparams.pooling_type = params.pooling_type;
-    cparams.attention_type = params.attention_type;
     cparams.defrag_thold = params.defrag_thold;
     cparams.cb_eval = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;

common/common.h

Lines changed: 0 additions & 1 deletion
@@ -94,7 +94,6 @@ struct gpt_params {
     enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
-    enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type

     // // sampling parameters
     struct llama_sampling_params sparams;

examples/gritlm/gritlm.cpp

Lines changed: 10 additions & 12 deletions
@@ -44,6 +44,8 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve

     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_cache_clear(ctx);
+    llama_set_embeddings(ctx, true);
+    llama_set_causal_attn(ctx, false);

     // run model
     llama_decode(ctx, batch);
@@ -97,6 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
     llama_token eos_token = llama_token_eos(mdl);

     llama_kv_cache_clear(ctx);
+    llama_set_embeddings(ctx, false);
+    llama_set_causal_attn(ctx, true);
+
     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);

     std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -165,13 +170,7 @@ int main(int argc, char * argv[]) {
     llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);

     // create generation context
-    llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams);
-
-    // create embedding context
-    cparams.embeddings = true;
-    cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
-    cparams.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL;
-    llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams);
+    llama_context * ctx = llama_new_context_with_model(mdl, cparams);

     // ### Embedding/Representation ###
     // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -189,8 +188,8 @@ int main(int argc, char * argv[]) {
     };

     // No need to add instruction for retrieval documents
-    const std::vector<std::vector<float>> d_rep = encode(ctx_emb, documents, gritlm_instruction(""));
-    const std::vector<std::vector<float>> q_rep = encode(ctx_emb, queries, gritlm_instruction(instruction));
+    const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
+    const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));

     const int n_embd = llama_n_embd(mdl);

@@ -209,11 +208,10 @@ int main(int argc, char * argv[]) {
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
     {
         const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
-        std::string response = generate(ctx_gen, prompt, true);
+        std::string response = generate(ctx, prompt, true);
     }

-    llama_free(ctx_gen);
-    llama_free(ctx_emb);
+    llama_free(ctx);
     llama_free_model(mdl);
     llama_backend_free();

llama.cpp

Lines changed: 5 additions & 7 deletions
@@ -15931,7 +15931,6 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
-        /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
         /*.rope_freq_base =*/ 0.0f,
         /*.rope_freq_scale =*/ 0.0f,
         /*.yarn_ext_factor =*/ -1.0f,
@@ -16173,12 +16172,7 @@ struct llama_context * llama_new_context_with_model(
     }

     cparams.yarn_attn_factor *= hparams.rope_attn_factor;
-
-    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
-        cparams.causal_attn = hparams.causal_attn;
-    } else {
-        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
-    }
+    cparams.causal_attn = hparams.causal_attn;

     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
         if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -17914,6 +17908,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
     ctx->abort_callback_data = abort_callback_data;
 }

+void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
+    ctx->cparams.embeddings = embeddings;
+}
+
 void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
     ctx->cparams.causal_attn = causal_attn;
 }

llama.h

Lines changed: 4 additions & 7 deletions
@@ -177,12 +177,6 @@ extern "C" {
         LLAMA_POOLING_TYPE_LAST = 3,
     };

-    enum llama_attention_type {
-        LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
-        LLAMA_ATTENTION_TYPE_CAUSAL = 0,
-        LLAMA_ATTENTION_TYPE_NONCAUSAL = 1,
-    };
-
     enum llama_split_mode {
         LLAMA_SPLIT_MODE_NONE = 0, // single GPU
         LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -300,7 +294,6 @@ extern "C" {

         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
-        enum llama_attention_type attention_type; // causal, non-causal, or unspecified

         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base; // RoPE base frequency, 0 = from model
@@ -793,6 +786,10 @@ extern "C" {
     // Get the number of threads used for prompt and batch processing (multiple token).
     LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);

+    // Set whether the model is in embeddings mode or not
+    // If true, embeddings will be returned but logits will not
+    LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
+
     // Set whether to use causal attention or not
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
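For reference, here is a minimal usage sketch of the two runtime setters involved in this change, following the single-context pattern that the gritlm example adopts above. It is not part of the commit: ctx, emb_batch, and gen_batch are assumed to be a llama_context and llama_batch objects prepared beforehand (e.g. as in examples/gritlm/gritlm.cpp).

    // sketch only: toggle one context between embedding and generation modes
    // assumes ctx, emb_batch, gen_batch have already been created elsewhere

    // embedding pass: embeddings are returned, logits are not; no causal mask
    llama_set_embeddings(ctx, true);
    llama_set_causal_attn(ctx, false);
    llama_kv_cache_clear(ctx);
    llama_decode(ctx, emb_batch);   // read results via llama_get_embeddings_seq()/llama_get_embeddings_ith()

    // generation pass: logits back on, causal attention restored
    llama_set_embeddings(ctx, false);
    llama_set_causal_attn(ctx, true);
    llama_kv_cache_clear(ctx);
    llama_decode(ctx, gen_batch);   // read logits via llama_get_logits_ith()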
