@@ -190,6 +190,12 @@ int32_t cpu_get_num_math() {
// CLI argument parsing
//

+void gpt_params_handle_hf_token(gpt_params & params) {
+    if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
+        params.hf_token = std::getenv("HF_TOKEN");
+    }
+}
+
void gpt_params_handle_model_default(gpt_params & params) {
    if (!params.hf_repo.empty()) {
        // short-hand to avoid specifying --hf-file -> default it to --model
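
Note: because the environment is consulted only when params.hf_token is still empty, an explicit --hf-token argument always takes precedence over the HF_TOKEN environment variable.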
@@ -237,6 +243,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {

    gpt_params_handle_model_default(params);

+    gpt_params_handle_hf_token(params);
+
    if (params.escape) {
        string_process_escapes(params.prompt);
        string_process_escapes(params.input_prefix);
@@ -652,6 +660,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
        params.model_url = argv[i];
        return true;
    }
+    if (arg == "-hft" || arg == "--hf-token") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.hf_token = argv[i];
+        return true;
+    }
    if (arg == "-hfr" || arg == "--hf-repo") {
        CHECK_ARG
        params.hf_repo = argv[i];
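
Note: the new -hft/--hf-token branch bounds-checks its argument by hand (++i >= argc) rather than using the CHECK_ARG macro that the neighbouring -hfr/-hff options use; the two forms are functionally equivalent here.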
@@ -1576,6 +1592,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
    options.push_back({ "*",           "-mu,  --model-url MODEL_URL",   "model download url (default: unused)" });
    options.push_back({ "*",           "-hfr, --hf-repo REPO",          "Hugging Face model repository (default: unused)" });
    options.push_back({ "*",           "-hff, --hf-file FILE",          "Hugging Face model file (default: unused)" });
+    options.push_back({ "*",           "-hft, --hf-token TOKEN",        "Hugging Face access token (default: value from HF_TOKEN environment variable)" });

    options.push_back({ "retrieval" });
    options.push_back({ "retrieval",   "--context-file FNAME",          "file to load context from (repeat to specify multiple files)" });
@@ -2015,9 +2032,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
    llama_model * model = nullptr;

    if (!params.hf_repo.empty() && !params.hf_file.empty()) {
-        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
+        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
    } else if (!params.model_url.empty()) {
-        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
+        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
    } else {
        model = llama_load_model_from_file(params.model.c_str(), mparams);
    }
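
Note: with this change, callers outside of gpt_params can pass a token to the loaders directly. A minimal sketch (not part of the patch; the repo and file names are placeholders):

    #include "common.h"
    #include "llama.h"

    #include <cstdlib>

    int main() {
        llama_model_params mparams = llama_model_default_params();

        // may be null when the repo is public and no token is needed
        const char * token = std::getenv("HF_TOKEN");

        llama_model * model = llama_load_model_from_hf(
            "some-org/some-model-GGUF",     // placeholder repo
            "model-q4_k_m.gguf",            // placeholder file
            "/tmp/model-q4_k_m.gguf",       // local path for the download
            token ? token : "",
            mparams);

        if (model == nullptr) {
            return 1;
        }

        llama_free_model(model);
        return 0;
    }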
@@ -2205,7 +2222,7 @@ static bool starts_with(const std::string & str, const std::string & prefix) {
    return str.rfind(prefix, 0) == 0;
}

-static bool llama_download_file(const std::string & url, const std::string & path) {
+static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {

    // Initialize libcurl
    std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
@@ -2220,6 +2237,15 @@ static bool llama_download_file(const std::string & url, const std::string & pat
    curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
    curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);

+    // Check if hf-token or bearer-token was specified
+    if (!hf_token.empty()) {
+        std::string auth_header = "Authorization: Bearer ";
+        auth_header += hf_token.c_str();
+        struct curl_slist * http_headers = NULL;
+        http_headers = curl_slist_append(http_headers, auth_header.c_str());
+        curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
+    }
+
#if defined(_WIN32)
    // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
    //   operating system. Currently implemented under MS-Windows.
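
Note: curl_slist_append heap-allocates the header list and libcurl does not take ownership of it, so as written the list is never released (the trailing .c_str() on hf_token is also redundant when appending to a std::string). A hedged sketch of the conventional pattern, freeing the list once the transfer is done:

    struct curl_slist * http_headers = nullptr;
    if (!hf_token.empty()) {
        const std::string auth_header = "Authorization: Bearer " + hf_token;
        http_headers = curl_slist_append(http_headers, auth_header.c_str());
        curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
    }

    // ... curl_easy_perform(...) and result handling ...

    if (http_headers) {
        curl_slist_free_all(http_headers);
    }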
@@ -2415,14 +2441,15 @@ static bool llama_download_file(const std::string & url, const std::string & pat
struct llama_model * llama_load_model_from_url(
        const char * model_url,
        const char * path_model,
+        const char * hf_token,
        const struct llama_model_params & params) {
    // Basic validation of the model_url
    if (!model_url || strlen(model_url) == 0) {
        fprintf(stderr, "%s: invalid model_url\n", __func__);
        return NULL;
    }

-    if (!llama_download_file(model_url, path_model)) {
+    if (!llama_download_file(model_url, path_model, hf_token)) {
        return NULL;
    }

@@ -2470,14 +2497,14 @@ struct llama_model * llama_load_model_from_url(
    // Prepare download in parallel
    std::vector<std::future<bool>> futures_download;
    for (int idx = 1; idx < n_split; idx++) {
-        futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split](int download_idx) -> bool {
+        futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
            char split_path[PATH_MAX] = {0};
            llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);

            char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
            llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);

-            return llama_download_file(split_url, split_path);
+            return llama_download_file(split_url, split_path, hf_token);
        }, idx));
    }

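
Note: the lambda captures hf_token by value, so each std::async worker owns its own copy of the token string and the parallel downloads do not depend on the lifetime of the caller's string.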
@@ -2496,6 +2523,7 @@ struct llama_model * llama_load_model_from_hf(
        const char * repo,
        const char * model,
        const char * path_model,
+        const char * hf_token,
        const struct llama_model_params & params) {
    // construct hugging face model url:
    //
@@ -2511,14 +2539,15 @@ struct llama_model * llama_load_model_from_hf(
    model_url += "/resolve/main/";
    model_url += model;

-    return llama_load_model_from_url(model_url.c_str(), path_model, params);
+    return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params);
}

#else

struct llama_model * llama_load_model_from_url(
        const char * /*model_url*/,
        const char * /*path_model*/,
+        const char * /*hf_token*/,
        const struct llama_model_params & /*params*/) {
    fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
    return nullptr;
@@ -2528,6 +2557,7 @@ struct llama_model * llama_load_model_from_hf(
        const char * /*repo*/,
        const char * /*model*/,
        const char * /*path_model*/,
+        const char * /*hf_token*/,
        const struct llama_model_params & /*params*/) {
    fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
    return nullptr;