Skip to content

Commit 2ada95b

Browse files
committed
Add new hf:// protocol handling for Ollama-style model downloads
https://huggingface.co/docs/hub/en/ollama Signed-off-by: Eric Curtin <[email protected]>
1 parent 564804b commit 2ada95b

File tree

1 file changed

+22
-19
lines changed

1 file changed

+22
-19
lines changed

examples/run/run.cpp

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -563,8 +563,8 @@ class LlamaData {
563563

564564
private:
565565
#ifdef LLAMA_USE_CURL
566-
int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
567-
const bool progress, std::string * response_str = nullptr) {
566+
int download(const std::string & url, const std::string & output_file, const bool progress,
567+
const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
568568
HttpClient http;
569569
if (http.init(url, headers, output_file, progress, response_str)) {
570570
return 1;
@@ -573,14 +573,20 @@ class LlamaData {
573573
return 0;
574574
}
575575
#else
576-
int download(const std::string &, const std::vector<std::string> &, const std::string &, const bool,
576+
int download(const std::string &, const std::string &, const bool, const std::vector<std::string> &,
577577
std::string * = nullptr) {
578578
printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
579579
return 1;
580580
}
581581
#endif
582582

583-
int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
583+
int huggingface_dl(const std::string & model, const std::string & bn) {
584+
std::vector<std::string> headers = {};
585+
if (model.find('/') != std::string::npos) {
586+
headers.push_back("User-Agent: llama-cpp");
587+
headers.push_back("Accept: application/json");
588+
}
589+
584590
// Find the second occurrence of '/' after protocol string
585591
size_t pos = model.find('/');
586592
pos = model.find('/', pos + 1);
@@ -591,10 +597,11 @@ class LlamaData {
591597
const std::string hfr = model.substr(0, pos);
592598
const std::string hff = model.substr(pos + 1);
593599
const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
594-
return download(url, headers, bn, true);
600+
return download(url, bn, true, headers);
595601
}
596602

597-
int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
603+
int ollama_dl(std::string & model, const std::string & bn) {
604+
const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
598605
if (model.find('/') == std::string::npos) {
599606
model = "library/" + model;
600607
}
@@ -608,7 +615,7 @@ class LlamaData {
608615

609616
std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag;
610617
std::string manifest_str;
611-
const int ret = download(manifest_url, headers, "", false, &manifest_str);
618+
const int ret = download(manifest_url, "", false, {}, &manifest_str);
612619
if (ret) {
613620
return ret;
614621
}
@@ -623,7 +630,7 @@ class LlamaData {
623630
}
624631

625632
std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
626-
return download(blob_url, headers, bn, true);
633+
return download(blob_url, bn, true, headers);
627634
}
628635

629636
std::string basename(const std::string & path) {
@@ -653,22 +660,18 @@ class LlamaData {
653660
return ret;
654661
}
655662

656-
const std::string bn = basename(model_);
657-
const std::vector<std::string> headers = { "--header",
658-
"Accept: application/vnd.docker.distribution.manifest.v2+json" };
663+
const std::string bn = basename(model_);
659664
if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
660665
rm_until_substring(model_, "://");
661-
ret = huggingface_dl(model_, headers, bn);
666+
ret = huggingface_dl(model_, bn);
662667
} else if (string_starts_with(model_, "hf.co/")) {
663668
rm_until_substring(model_, "hf.co/");
664-
ret = huggingface_dl(model_, headers, bn);
665-
} else if (string_starts_with(model_, "ollama://")) {
666-
rm_until_substring(model_, "://");
667-
ret = ollama_dl(model_, headers, bn);
669+
ret = huggingface_dl(model_, bn);
668670
} else if (string_starts_with(model_, "https://")) {
669-
ret = download(model_, headers, bn, true);
670-
} else {
671-
ret = ollama_dl(model_, headers, bn);
671+
ret = download(model_, bn, true);
672+
} else { // ollama:// or nothing
673+
rm_until_substring(model_, "://");
674+
ret = ollama_dl(model_, bn);
672675
}
673676

674677
model_ = bn;

0 commit comments

Comments
 (0)