Skip to content

Commit 4f3ba1c

Browse files
ngxsonpockers21
authored andcommitted
common : add common_remote_get_content (ggml-org#13123)
* common : add common_remote_get_content * support max size and timeout * add tests
1 parent cc00eb2 commit 4f3ba1c

File tree

3 files changed

+127
-34
lines changed

3 files changed

+127
-34
lines changed

common/arg.cpp

Lines changed: 71 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@ struct common_hf_file_res {
162162

163163
#ifdef LLAMA_USE_CURL
164164

165+
bool common_has_curl() {
166+
return true;
167+
}
168+
165169
#ifdef __linux__
166170
#include <linux/limits.h>
167171
#elif defined(_WIN32)
@@ -527,64 +531,89 @@ static bool common_download_model(
527531
return true;
528532
}
529533

530-
/**
531-
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
532-
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
533-
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
534-
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
535-
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
536-
*
537-
* Return pair of <repo, file> (with "repo" already having tag removed)
538-
*
539-
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
540-
*/
541-
static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) {
542-
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
543-
std::string tag = parts.size() > 1 ? parts.back() : "latest";
544-
std::string hf_repo = parts[0];
545-
if (string_split<std::string>(hf_repo, '/').size() != 2) {
546-
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
547-
}
548-
549-
// fetch model info from Hugging Face Hub API
534+
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) {
550535
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
551536
curl_slist_ptr http_headers;
552-
std::string res_str;
537+
std::vector<char> res_buffer;
553538

554-
std::string model_endpoint = get_model_endpoint();
555-
556-
std::string url = model_endpoint + "v2/" + hf_repo + "/manifests/" + tag;
557539
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
558540
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
541+
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
559542
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
560543
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
561-
static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
544+
auto data_vec = static_cast<std::vector<char> *>(data);
545+
data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb);
562546
return size * nmemb;
563547
};
564548
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
565-
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
549+
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_buffer);
566550
#if defined(_WIN32)
567551
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
568552
#endif
569-
if (!bearer_token.empty()) {
570-
std::string auth_header = "Authorization: Bearer " + bearer_token;
571-
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
553+
if (params.timeout > 0) {
554+
curl_easy_setopt(curl.get(), CURLOPT_TIMEOUT, params.timeout);
555+
}
556+
if (params.max_size > 0) {
557+
curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size);
572558
}
573-
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
574559
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
575-
http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
560+
for (const auto & header : params.headers) {
561+
http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str());
562+
}
576563
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
577564

578565
CURLcode res = curl_easy_perform(curl.get());
579566

580567
if (res != CURLE_OK) {
581-
throw std::runtime_error("error: cannot make GET request to HF API");
568+
std::string error_msg = curl_easy_strerror(res);
569+
throw std::runtime_error("error: cannot make GET request: " + error_msg);
582570
}
583571

584572
long res_code;
585-
std::string ggufFile = "";
586-
std::string mmprojFile = "";
587573
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
574+
575+
return { res_code, std::move(res_buffer) };
576+
}
577+
578+
/**
579+
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
580+
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
581+
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
582+
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
583+
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
584+
*
585+
* Return pair of <repo, file> (with "repo" already having tag removed)
586+
*
587+
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
588+
*/
589+
static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) {
590+
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
591+
std::string tag = parts.size() > 1 ? parts.back() : "latest";
592+
std::string hf_repo = parts[0];
593+
if (string_split<std::string>(hf_repo, '/').size() != 2) {
594+
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
595+
}
596+
597+
std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
598+
599+
// headers
600+
std::vector<std::string> headers;
601+
headers.push_back("Accept: application/json");
602+
if (!bearer_token.empty()) {
603+
headers.push_back("Authorization: Bearer " + bearer_token);
604+
}
605+
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
606+
// User-Agent header is already set in common_remote_get_content, no need to set it here
607+
608+
// make the request
609+
common_remote_params params;
610+
params.headers = headers;
611+
auto res = common_remote_get_content(url, params);
612+
long res_code = res.first;
613+
std::string res_str(res.second.data(), res.second.size());
614+
std::string ggufFile;
615+
std::string mmprojFile;
616+
588617
if (res_code == 200) {
589618
// extract ggufFile.rfilename in json, using regex
590619
{
@@ -618,6 +647,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
618647

619648
#else
620649

650+
bool common_has_curl() {
651+
return false;
652+
}
653+
621654
static bool common_download_file_single(const std::string &, const std::string &, const std::string &) {
622655
LOG_ERR("error: built without CURL, cannot download model from internet\n");
623656
return false;
@@ -640,6 +673,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string &, const s
640673
return {};
641674
}
642675

676+
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) {
677+
throw std::runtime_error("error: built without CURL, cannot download model from the internet");
678+
}
679+
643680
#endif // LLAMA_USE_CURL
644681

645682
//

common/arg.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,12 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
7878

7979
// function to be used by test-arg-parser
8080
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
81+
bool common_has_curl();
82+
83+
struct common_remote_params {
84+
std::vector<std::string> headers;
85+
long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout
86+
long max_size = 0; // max size of the response ; unlimited if 0 ; max is 2GB
87+
};
88+
// get remote file content, returns <http_code, raw_response_body>
89+
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);

tests/test-arg-parser.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,53 @@ int main(void) {
126126
assert(params.cpuparams.n_threads == 1010);
127127
#endif // _WIN32
128128

129+
if (common_has_curl()) {
130+
printf("test-arg-parser: test curl-related functions\n\n");
131+
const char * GOOD_URL = "https://raw.githubusercontent.com/ggml-org/llama.cpp/refs/heads/master/README.md";
132+
const char * BAD_URL = "https://www.google.com/404";
133+
const char * BIG_FILE = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin";
134+
135+
{
136+
printf("test-arg-parser: test good URL\n\n");
137+
auto res = common_remote_get_content(GOOD_URL, {});
138+
assert(res.first == 200);
139+
assert(res.second.size() > 0);
140+
std::string str(res.second.data(), res.second.size());
141+
assert(str.find("llama.cpp") != std::string::npos);
142+
}
143+
144+
{
145+
printf("test-arg-parser: test bad URL\n\n");
146+
auto res = common_remote_get_content(BAD_URL, {});
147+
assert(res.first == 404);
148+
}
149+
150+
{
151+
printf("test-arg-parser: test max size error\n");
152+
common_remote_params params;
153+
params.max_size = 1;
154+
try {
155+
common_remote_get_content(GOOD_URL, params);
156+
assert(false && "it should throw an error");
157+
} catch (std::exception & e) {
158+
printf(" expected error: %s\n\n", e.what());
159+
}
160+
}
161+
162+
{
163+
printf("test-arg-parser: test timeout error\n");
164+
common_remote_params params;
165+
params.timeout = 1;
166+
try {
167+
common_remote_get_content(BIG_FILE, params);
168+
assert(false && "it should throw an error");
169+
} catch (std::exception & e) {
170+
printf(" expected error: %s\n\n", e.what());
171+
}
172+
}
173+
} else {
174+
printf("test-arg-parser: no curl, skipping curl-related functions\n");
175+
}
129176

130177
printf("test-arg-parser: all tests OK\n\n");
131178
}

0 commit comments

Comments
 (0)