@@ -319,6 +319,10 @@ class HttpClient {
   public:
     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
              const bool progress, std::string * response_str = nullptr) {
+        if (std::filesystem::exists(output_file)) {
+            return 0;
+        }
+
         std::string output_file_partial;
         curl = curl_easy_init();
         if (!curl) {
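The guard added to `HttpClient::init` makes repeated runs a no-op when the target file already exists, so a model is only fetched once. A minimal standalone sketch of the same idea (the `fetch_if_missing` helper is illustrative, not part of the patch):

```cpp
#include <cstdio>
#include <filesystem>
#include <string>

// Hypothetical helper mirroring the new early return in HttpClient::init:
// if the output file is already on disk, skip the network entirely.
static int fetch_if_missing(const std::string & url, const std::string & output_file) {
    if (std::filesystem::exists(output_file)) {
        return 0;  // cached: nothing to download
    }
    std::printf("downloading %s -> %s\n", url.c_str(), output_file.c_str());
    // ... the real code performs the transfer with libcurl here ...
    return 0;
}
```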
@@ -563,8 +567,8 @@ class LlamaData {
 
   private:
 #ifdef LLAMA_USE_CURL
-    int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
-                 const bool progress, std::string * response_str = nullptr) {
+    int download(const std::string & url, const std::string & output_file, const bool progress,
+                 const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
         HttpClient http;
         if (http.init(url, headers, output_file, progress, response_str)) {
             return 1;
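Reordering `download` so that `headers` sits behind a default argument lets common call sites pass just the URL, output file, and progress flag. A compilable sketch of the new call shapes (the stub body is illustrative; the real function delegates to `HttpClient`):

```cpp
#include <string>
#include <vector>

// Stub with the patched parameter order: headers and response_str are optional.
static int download(const std::string & url, const std::string & output_file, const bool progress,
                    const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
    (void) url; (void) output_file; (void) progress; (void) headers; (void) response_str;
    return 0;
}

int main() {
    std::string manifest_str;
    download("https://example.com/model.gguf", "model.gguf", true);  // plain file fetch, no headers
    download("https://example.com/manifest", "", false, { "Accept: application/json" }, &manifest_str);
    return 0;
}
```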
@@ -573,57 +577,92 @@ class LlamaData {
         return 0;
     }
 #else
-    int download(const std::string &, const std::vector<std::string> &, const std::string &, const bool,
+    int download(const std::string &, const std::string &, const bool, const std::vector<std::string> & = {},
                  std::string * = nullptr) {
         printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
         return 1;
     }
 #endif
 
-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    // Helper function to handle model tag extraction and URL construction
+    std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
+        std::string  model_tag = "latest";
+        const size_t colon_pos = model.find(':');
+        if (colon_pos != std::string::npos) {
+            model_tag = model.substr(colon_pos + 1);
+            model     = model.substr(0, colon_pos);
+        }
+
+        std::string url = base_url + model + "/manifests/" + model_tag;
+
+        return { model, url };
+    }
+
+    // Helper function to download and parse the manifest
+    int download_and_parse_manifest(const std::string & url, const std::vector<std::string> & headers,
+                                    nlohmann::json & manifest) {
+        std::string manifest_str;
+        int         ret = download(url, "", false, headers, &manifest_str);
+        if (ret) {
+            return ret;
+        }
+
+        manifest = nlohmann::json::parse(manifest_str);
+
+        return 0;
+    }
+
+    int huggingface_dl(std::string & model, const std::string & bn) {
         // Find the second occurrence of '/' after protocol string
         size_t pos = model.find('/');
         pos        = model.find('/', pos + 1);
+        std::string              hfr, hff;
+        std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
+        std::string              url;
+
         if (pos == std::string::npos) {
-            return 1;
+            auto [model_name, manifest_url] = extract_model_and_tag(model, "https://huggingface.co/v2/");
+            hfr                             = model_name;
+
+            nlohmann::json manifest;
+            int            ret = download_and_parse_manifest(manifest_url, headers, manifest);
+            if (ret) {
+                return ret;
+            }
+
+            hff = manifest["ggufFile"]["rfilename"];
+        } else {
+            hfr = model.substr(0, pos);
+            hff = model.substr(pos + 1);
         }
 
-        const std::string hfr = model.substr(0, pos);
-        const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
+        url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+        return download(url, bn, true, headers);
     }
 
-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int ollama_dl(std::string & model, const std::string & bn) {
+        const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
         if (model.find('/') == std::string::npos) {
             model = "library/" + model;
         }
 
-        std::string model_tag = "latest";
-        size_t      colon_pos = model.find(':');
-        if (colon_pos != std::string::npos) {
-            model_tag = model.substr(colon_pos + 1);
-            model     = model.substr(0, colon_pos);
-        }
-
-        std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag;
-        std::string manifest_str;
-        const int   ret          = download(manifest_url, headers, "", false, &manifest_str);
+        auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/");
+        nlohmann::json manifest;
+        int            ret = download_and_parse_manifest(manifest_url, {}, manifest);
         if (ret) {
             return ret;
         }
 
-        nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string    layer;
+        std::string layer;
         for (const auto & l : manifest["layers"]) {
             if (l["mediaType"] == "application/vnd.ollama.image.model") {
                 layer = l["digest"];
                 break;
             }
         }
 
-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer;
+        return download(blob_url, bn, true, headers);
     }
 
     std::string basename(const std::string & path) {
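The factored-out `extract_model_and_tag` applies the same `name:tag` splitting rule to both registries, defaulting the tag to `latest`. A self-contained sketch of that rule (helper name and inputs here are hypothetical):

```cpp
#include <iostream>
#include <string>
#include <utility>

// Same splitting rule as extract_model_and_tag: "name:tag" -> ("name", "tag"),
// with "latest" as the default when no ':' is present.
static std::pair<std::string, std::string> split_model_tag(std::string model) {
    std::string  tag       = "latest";
    const size_t colon_pos = model.find(':');
    if (colon_pos != std::string::npos) {
        tag   = model.substr(colon_pos + 1);
        model = model.substr(0, colon_pos);
    }
    return { model, tag };
}

int main() {
    const auto [m1, t1] = split_model_tag("library/llama3.2:1b");
    const auto [m2, t2] = split_model_tag("library/llama3.2");
    std::cout << m1 << " -> " << t1 << "\n";  // library/llama3.2 -> 1b
    std::cout << m2 << " -> " << t2 << "\n";  // library/llama3.2 -> latest
    return 0;
}
```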
@@ -653,22 +692,18 @@ class LlamaData {
             return ret;
         }
 
-        const std::string              bn      = basename(model_);
-        const std::vector<std::string> headers = { "--header",
-                                                   "Accept: application/vnd.docker.distribution.manifest.v2+json" };
+        const std::string bn = basename(model_);
         if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
             rm_until_substring(model_, "://");
-            ret = huggingface_dl(model_, headers, bn);
+            ret = huggingface_dl(model_, bn);
         } else if (string_starts_with(model_, "hf.co/")) {
             rm_until_substring(model_, "hf.co/");
-            ret = huggingface_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "ollama://")) {
-            rm_until_substring(model_, "://");
-            ret = ollama_dl(model_, headers, bn);
+            ret = huggingface_dl(model_, bn);
         } else if (string_starts_with(model_, "https://")) {
-            ret = download(model_, headers, bn, true);
-        } else {
-            ret = ollama_dl(model_, headers, bn);
+            ret = download(model_, bn, true);
+        } else {  // ollama:// or nothing
+            rm_until_substring(model_, "://");
+            ret = ollama_dl(model_, bn);
         }
 
         model_ = bn;
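After the rework, `ollama://` URLs and bare model names share the final branch, and the registry-specific headers live inside the downloaders rather than at the call site. A compact sketch of the resulting resolution order (prefixes as in the patch; the function itself is illustrative):

```cpp
#include <string>

// Mirrors the patched scheme dispatch: returns which downloader a given
// model string would be routed to.
static const char * resolve_source(const std::string & model) {
    const auto starts_with = [&](const char * prefix) { return model.rfind(prefix, 0) == 0; };
    if (starts_with("hf://") || starts_with("huggingface://") || starts_with("hf.co/")) {
        return "huggingface";
    }
    if (starts_with("https://")) {
        return "direct";
    }
    return "ollama";  // ollama:// or a bare name such as "llama3.2:1b"
}
```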