@@ -473,6 +473,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         else { invalid_param = true; }
         return true;
     }
+    if (arg == "--attention") {
+        CHECK_ARG
+        std::string value(argv[i]);
+        /**/ if (value == "causal")     { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
+        else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; }
+        else { invalid_param = true; }
+        return true;
+    }
     if (arg == "--defrag-thold" || arg == "-dt") {
         CHECK_ARG
         params.defrag_thold = std::stof(argv[i]);
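
Note (not part of the patch): a rough sketch of how the new --attention value is expected to flow into the context, using only names that appear in this diff; treat it as an assumption about the surrounding llama.cpp code, not as part of the change.

    gpt_params params;
    params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL;   // what "--attention non-causal" sets
    llama_context_params cparams = llama_context_params_from_gpt_params(params);
    // cparams.attention_type (added at the bottom of this diff) then carries the
    // choice into context creation; embedding models would typically use non-causal.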
@@ -758,7 +766,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.cache_type_v = argv[++i];
         return true;
     }
-    if (arg == "--multiline-input") {
+    if (arg == "-mli" || arg == "--multiline-input") {
         params.multiline_input = true;
         return true;
     }
@@ -1395,7 +1403,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*",           "--keep N",              "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep });
     options.push_back({ "*",           "--chunks N",            "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
     options.push_back({ "*",           "-fa, --flash-attn",     "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
-    options.push_back({ "*",           "-p, --prompt PROMPT",   "prompt to start generation with (default: '%s')", params.prompt.c_str() });
+    options.push_back({ "*",           "-p, --prompt PROMPT",   "prompt to start generation with\n"
+                                                                "in conversation mode, this will be used as system prompt\n"
+                                                                "(default: '%s')", params.prompt.c_str() });
     options.push_back({ "*",           "-f, --file FNAME",      "a file containing the prompt (default: none)" });
     options.push_back({ "*",           "--in-file FNAME",       "an input file (repeat to specify multiple files)" });
     options.push_back({ "*",           "-bf, --binary-file FNAME", "binary file containing the prompt (default: none)" });
@@ -1410,7 +1420,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
                                                                 "halt generation at PROMPT, return control in interactive mode\n"
                                                                 "can be specified more than once for multiple prompts" });
     options.push_back({ "main",        "-sp, --special",        "special tokens output enabled (default: %s)", params.special ? "true" : "false" });
-    options.push_back({ "main",        "-cnv, --conversation",  "run in conversation mode (does not print special tokens and suffix/prefix, use default chat template) (default: %s)", params.conversation ? "true" : "false" });
+    options.push_back({ "main",        "-cnv, --conversation",  "run in conversation mode, does not print special tokens and suffix/prefix\n"
+                                                                "if suffix/prefix are not specified, default chat template will be used\n"
+                                                                "(default: %s)", params.conversation ? "true" : "false" });
     options.push_back({ "main infill", "-i, --interactive",     "run in interactive mode (default: %s)", params.interactive ? "true" : "false" });
     options.push_back({ "main infill", "-if, --interactive-first", "run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false" });
     options.push_back({ "main infill", "-mli, --multiline-input", "allows you to write or paste multiple lines without ending each in '\\'" });
@@ -1454,6 +1466,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "main",        "--cfg-scale N",         "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
     options.push_back({ "main",        "--chat-template JINJA_TEMPLATE",
                                                                 "set custom jinja chat template (default: template taken from model's metadata)\n"
+                                                                "if suffix/prefix are specified, template will be disabled\n"
                                                                 "only commonly used templates are accepted:\n"
                                                                 "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
     options.push_back({ "grammar" });
@@ -1464,8 +1477,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
                                                                 "For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead" });
 
     options.push_back({ "embedding" });
-    options.push_back({ "embedding",   "--pooling {none,mean,cls}",
+    options.push_back({ "embedding",   "--pooling {none,mean,cls,last}",
                                                                 "pooling type for embeddings, use model default if unspecified" });
+    options.push_back({ "embedding",   "--attention {causal,non-causal}",
+                                                                "attention type for embeddings, use model default if unspecified" });
 
     options.push_back({ "context hacking" });
     options.push_back({ "*",           "--rope-scaling {none,linear,yarn}",
@@ -2071,7 +2086,24 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     if (params.warmup) {
         LOG("warming up the model with an empty run\n");
 
-        std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
+        std::vector<llama_token> tmp;
+        llama_token bos = llama_token_bos(model);
+        llama_token eos = llama_token_eos(model);
+        // some models (e.g. T5) don't have a BOS token
+        if (bos != -1) {
+            tmp.push_back(bos);
+        }
+        tmp.push_back(eos);
+
+        if (llama_model_has_encoder(model)) {
+            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
+            llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+            if (decoder_start_token_id == -1) {
+                decoder_start_token_id = bos;
+            }
+            tmp.clear();
+            tmp.push_back(decoder_start_token_id);
+        }
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
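
Note (not part of the patch): the warmup above mirrors the encode-then-decode order an encoder-decoder model (e.g. T5) needs at inference time; a minimal sketch under that assumption, reusing only API names from this hunk and assuming prompt_tokens holds a tokenized prompt.

    if (llama_model_has_encoder(model)) {
        // run the encoder once over the whole prompt
        llama_encode(lctx, llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size(), 0, 0));
        // then prime the decoder with its start token, falling back to BOS if unset
        llama_token start = llama_model_decoder_start_token(model);
        if (start == -1) { start = llama_token_bos(model); }
        llama_decode(lctx, llama_batch_get_one(&start, 1, 0, 0));
    }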
@@ -2154,6 +2186,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_slow    = params.yarn_beta_slow;
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
     cparams.pooling_type      = params.pooling_type;
+    cparams.attention_type    = params.attention_type;
     cparams.defrag_thold      = params.defrag_thold;
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;