
Commit c397513

dwoolworth authored and ngxson committed
added support for Authorization Bearer tokens when downloading model (ggml-org#8307)
* added support for Authorization Bearer tokens

* removed auth_token, removed set_ function, other small fixes

* Update common/common.cpp

---------

Co-authored-by: Xuan Son Nguyen <[email protected]>
1 parent 827e454 commit c397513

File tree

2 files changed: +41 additions, -9 deletions

    common/common.cpp
    common/common.h
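The diffs below thread a Hugging Face access token from the command line (or the HF_TOKEN environment variable) down to the libcurl download path. As a quick orientation, here is a minimal sketch of how a caller might drive the updated API after this commit; the repo and file names are placeholders, not from the commit:

#include "common.h"
#include "llama.h"

int main() {
    llama_backend_init();

    gpt_params params;
    params.hf_repo = "some-org/some-model";  // placeholder repo
    params.hf_file = "model-q4_0.gguf";      // placeholder file

    // params.hf_token left empty: the new gpt_params_handle_hf_token()
    // falls back to the HF_TOKEN environment variable.
    gpt_params_handle_hf_token(params);
    gpt_params_handle_model_default(params);

    llama_model_params mparams = llama_model_params_from_gpt_params(params);
    llama_model * model = llama_load_model_from_hf(
            params.hf_repo.c_str(), params.hf_file.c_str(),
            params.model.c_str(), params.hf_token.c_str(), mparams);

    if (model) {
        llama_free_model(model);
    }
    llama_backend_end();
    return model ? 0 : 1;
}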

common/common.cpp

Lines changed: 37 additions & 7 deletions
@@ -190,6 +190,12 @@ int32_t cpu_get_num_math() {
 // CLI argument parsing
 //
 
+void gpt_params_handle_hf_token(gpt_params & params) {
+    if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
+        params.hf_token = std::getenv("HF_TOKEN");
+    }
+}
+
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
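Since this fallback only fires when the token is still empty, an explicit --hf-token always takes precedence over the environment. A quick standalone check of that ordering (token values are made up; setenv is POSIX):

#include <cstdio>
#include <cstdlib>
#include <string>

// Mirrors the fallback above: the env var fills the token only when empty.
static void handle_hf_token(std::string & token) {
    if (token.empty() && std::getenv("HF_TOKEN")) {
        token = std::getenv("HF_TOKEN");
    }
}

int main() {
    setenv("HF_TOKEN", "hf_env_token", 1);   // hypothetical token value

    std::string from_flag = "hf_cli_token";  // as if --hf-token was passed
    handle_hf_token(from_flag);
    printf("%s\n", from_flag.c_str());       // prints hf_cli_token: flag wins

    std::string unset;                       // as if --hf-token was omitted
    handle_hf_token(unset);
    printf("%s\n", unset.c_str());           // prints hf_env_token: env fallback
    return 0;
}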
@@ -237,6 +243,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
 
     gpt_params_handle_model_default(params);
 
+    gpt_params_handle_hf_token(params);
+
     if (params.escape) {
         string_process_escapes(params.prompt);
         string_process_escapes(params.input_prefix);
@@ -652,6 +660,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.model_url = argv[i];
         return true;
     }
+    if (arg == "-hft" || arg == "--hf-token") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.hf_token = argv[i];
+        return true;
+    }
     if (arg == "-hfr" || arg == "--hf-repo") {
         CHECK_ARG
         params.hf_repo = argv[i];
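One stylistic quirk: the new -hft branch spells out its bounds check, while the neighboring -hfr branch relies on the CHECK_ARG macro. Judging from the surrounding parser, the macro presumably expands to the same guard, making the two forms equivalent:

// Presumed expansion of CHECK_ARG; an assumption inferred from the
// surrounding parser code, not something shown in this diff.
#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }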
@@ -1576,6 +1592,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-mu,  --model-url MODEL_URL", "model download url (default: unused)" });
     options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
     options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
+    options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });
 
     options.push_back({ "retrieval" });
     options.push_back({ "retrieval", "       --context-file FNAME", "file to load context from (repeat to specify multiple files)" });
@@ -2015,9 +2032,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     llama_model * model = nullptr;
 
     if (!params.hf_repo.empty() && !params.hf_file.empty()) {
-        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
+        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
     } else if (!params.model_url.empty()) {
-        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
+        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
     } else {
         model = llama_load_model_from_file(params.model.c_str(), mparams);
    }
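Passing params.hf_token.c_str() is safe even when no token was supplied: c_str() on an empty std::string returns a valid pointer to "", never NULL, and the download code below only adds the Authorization header when the token is non-empty. The invariant in two asserts:

#include <cassert>
#include <string>

int main() {
    std::string hf_token;                           // no token supplied
    assert(hf_token.c_str() != nullptr);            // always a valid pointer
    assert(std::string(hf_token.c_str()).empty());  // round-trips as empty
    return 0;
}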
@@ -2205,7 +2222,7 @@ static bool starts_with(const std::string & str, const std::string & prefix) {
     return str.rfind(prefix, 0) == 0;
 }
 
-static bool llama_download_file(const std::string & url, const std::string & path) {
+static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
 
     // Initialize libcurl
     std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
@@ -2220,6 +2237,15 @@ static bool llama_download_file(const std::string & url, const std::string & pat
     curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
     curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
 
+    // Check if hf-token or bearer-token was specified
+    if (!hf_token.empty()) {
+        std::string auth_header = "Authorization: Bearer ";
+        auth_header += hf_token.c_str();
+        struct curl_slist *http_headers = NULL;
+        http_headers = curl_slist_append(http_headers, auth_header.c_str());
+        curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
+    }
+
 #if defined(_WIN32)
     // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
     // operating system. Currently implemented under MS-Windows.
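Two details worth noting in this hunk. First, CURLOPT_HTTPHEADER only borrows the list, and nothing here calls curl_slist_free_all, so the list built above is never freed. Second, the .c_str() in auth_header += hf_token.c_str() is redundant, since std::string appends another std::string directly. A self-contained sketch of the same header setup with cleanup included (a standalone example, not the committed code):

#include <curl/curl.h>
#include <string>

// Attach a Bearer token to a curl handle and free the list after the
// transfer. Returns CURLE_OK on success.
static CURLcode fetch_with_bearer(const std::string & url, const std::string & token) {
    CURL * curl = curl_easy_init();
    if (!curl) {
        return CURLE_FAILED_INIT;
    }

    curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
    curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);

    struct curl_slist * headers = NULL;
    if (!token.empty()) {
        std::string auth_header = "Authorization: Bearer " + token;
        headers = curl_slist_append(headers, auth_header.c_str());
        curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
    }

    CURLcode res = curl_easy_perform(curl);

    // The handle only borrows the list; free it once the transfer is done.
    // curl_slist_free_all is documented as a no-op on a NULL list.
    curl_slist_free_all(headers);
    curl_easy_cleanup(curl);
    return res;
}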
@@ -2415,14 +2441,15 @@
 struct llama_model * llama_load_model_from_url(
         const char * model_url,
         const char * path_model,
+        const char * hf_token,
         const struct llama_model_params & params) {
     // Basic validation of the model_url
     if (!model_url || strlen(model_url) == 0) {
         fprintf(stderr, "%s: invalid model_url\n", __func__);
         return NULL;
     }
 
-    if (!llama_download_file(model_url, path_model)) {
+    if (!llama_download_file(model_url, path_model, hf_token)) {
         return NULL;
     }
 
@@ -2470,14 +2497,14 @@
     // Prepare download in parallel
     std::vector<std::future<bool>> futures_download;
     for (int idx = 1; idx < n_split; idx++) {
-        futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split](int download_idx) -> bool {
+        futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
             char split_path[PATH_MAX] = {0};
             llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
 
             char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
             llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
 
-            return llama_download_file(split_url, split_path);
+            return llama_download_file(split_url, split_path, hf_token);
         }, idx));
     }
 
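Because hf_token is a const char * in this function, capturing it by value copies only the pointer. That is safe here since the pointed-to std::string inside gpt_params outlives the futures, which are joined before llama_load_model_from_url returns. A minimal illustration of the capture pattern (hypothetical worker, not the committed code):

#include <cstdio>
#include <future>
#include <vector>

int main() {
    const char * token = "hf_example_token";  // hypothetical token

    std::vector<std::future<bool>> futures;
    for (int idx = 1; idx < 4; idx++) {
        // Capture the pointer by value; the string it points to must
        // outlive the future, which it does since we join below.
        futures.push_back(std::async(std::launch::async, [token](int i) -> bool {
            printf("worker %d sees token %s\n", i, token);
            return true;
        }, idx));
    }
    for (auto & f : futures) {
        f.get();  // join before the token goes out of scope
    }
    return 0;
}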
@@ -2496,6 +2523,7 @@ struct llama_model * llama_load_model_from_hf(
         const char * repo,
         const char * model,
         const char * path_model,
+        const char * hf_token,
         const struct llama_model_params & params) {
     // construct hugging face model url:
     //
@@ -2511,14 +2539,15 @@
     model_url += "/resolve/main/";
     model_url += model;
 
-    return llama_load_model_from_url(model_url.c_str(), path_model, params);
+    return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params);
 }
 
 #else
 
 struct llama_model * llama_load_model_from_url(
         const char * /*model_url*/,
         const char * /*path_model*/,
+        const char * /*hf_token*/,
         const struct llama_model_params & /*params*/) {
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
     return nullptr;
@@ -2528,6 +2557,7 @@ struct llama_model * llama_load_model_from_hf(
         const char * /*repo*/,
         const char * /*model*/,
         const char * /*path_model*/,
+        const char * /*hf_token*/,
         const struct llama_model_params & /*params*/) {
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
     return nullptr;
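In the no-curl build the stubs grow the same hf_token parameter, with the name commented out to silence unused-parameter warnings, so common.h can declare one signature for both build configurations. The same pattern in miniature; HAVE_DOWNLOAD stands in for the real gating macro (presumably LLAMA_USE_CURL in llama.cpp, not shown in this diff):

#include <cstdio>

#define HAVE_DOWNLOAD 1

// One declaration, two definitions selected at compile time.
int download(const char * url, const char * token);

#if HAVE_DOWNLOAD
int download(const char * url, const char * token) {
    printf("downloading %s (token %s)\n", url, token);
    return 0;
}
#else
int download(const char * /*url*/, const char * /*token*/) {
    fprintf(stderr, "built without download support\n");
    return -1;
}
#endif

int main() {
    return download("https://example.com/model.gguf", "hf_example");
}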

common/common.h

Lines changed: 4 additions & 2 deletions
@@ -108,6 +108,7 @@ struct gpt_params {
     std::string model_draft = ""; // draft model for speculative decoding
     std::string model_alias = "unknown"; // model alias
     std::string model_url = ""; // model url to download
+    std::string hf_token = ""; // HF token
     std::string hf_repo = ""; // HF repo
     std::string hf_file = ""; // HF file
     std::string prompt = "";
@@ -256,6 +257,7 @@ struct gpt_params {
     bool spm_infill = false; // suffix/prefix/middle pattern for infill
 };
 
+void gpt_params_handle_hf_token(gpt_params & params);
 void gpt_params_handle_model_default(gpt_params & params);
 
 bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params);
@@ -311,8 +313,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
-struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
-struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
+struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
+struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
 
 // Batch utils
