@@ -563,8 +563,8 @@ class LlamaData {
563
563
564
564
private:
565
565
#ifdef LLAMA_USE_CURL
566
- int download (const std::string & url, const std::vector<std:: string> & headers , const std::string & output_file ,
567
- const bool progress , std::string * response_str = nullptr ) {
566
+ int download (const std::string & url, const std::string & output_file , const bool progress ,
567
+ const std::vector<std::string> & headers = {} , std::string * response_str = nullptr ) {
568
568
HttpClient http;
569
569
if (http.init (url, headers, output_file, progress, response_str)) {
570
570
return 1 ;
@@ -573,57 +573,95 @@ class LlamaData {
573
573
return 0 ;
574
574
}
575
575
#else
576
- int download (const std::string &, const std::vector<std:: string> &, const std::string &, const bool ,
576
+ int download (const std::string &, const std::string &, const bool , const std::vector<std:: string> & = {} ,
577
577
std::string * = nullptr ) {
578
578
printe (" %s: llama.cpp built without libcurl, downloading from an url not supported.\n " , __func__);
579
579
return 1 ;
580
580
}
581
581
#endif
582
582
583
- int huggingface_dl (const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
583
// Splits an optional ":<tag>" suffix off `model` and builds the registry manifest URL.
//
// `model` is an in/out parameter: on return it holds the name with the tag (and the ':')
// removed. The tag defaults to "latest" when there is no ':' or when nothing follows it.
//
// Returns { model name without tag, base_url + name + "/manifests/" + tag }.
std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
    std::string  model_tag = "latest";
    const size_t colon_pos = model.find(':');
    if (colon_pos != std::string::npos) {
        // "name:" carries an empty tag; keeping the default avoids a manifest URL that
        // ends in "/manifests/".
        if (colon_pos + 1 < model.size()) {
            model_tag = model.substr(colon_pos + 1);
        }
        model.erase(colon_pos);
    }

    std::string url = base_url + model + "/manifests/" + model_tag;

    return { model, url };
}
596
+
597
+ // Helper function to download and parse the manifest
598
+ int download_and_parse_manifest (const std::string & url, const std::vector<std::string> & headers,
599
+ nlohmann::json & manifest) {
600
+ std::string manifest_str;
601
+ int ret = download (url, " " , false , headers, &manifest_str);
602
+ if (ret) {
603
+ return ret;
604
+ }
605
+
606
+ manifest = nlohmann::json::parse (manifest_str);
607
+
608
+ return 0 ;
609
+ }
610
+
611
+ int huggingface_dl (std::string & model, const std::string & bn) {
584
612
// Find the second occurrence of '/' after protocol string
585
613
size_t pos = model.find (' /' );
586
614
pos = model.find (' /' , pos + 1 );
615
+ std::string hfr, hff;
616
+ std::vector<std::string> headers = { " User-Agent: llama-cpp" , " Accept: application/json" };
617
+ std::string url;
618
+
587
619
if (pos == std::string::npos) {
588
- return 1 ;
620
+ auto [model_name, manifest_url] = extract_model_and_tag (model, " https://huggingface.co/v2/" );
621
+ hfr = model_name;
622
+
623
+ nlohmann::json manifest;
624
+ int ret = download_and_parse_manifest (manifest_url, headers, manifest);
625
+ if (ret) {
626
+ return ret;
627
+ }
628
+
629
+ hff = manifest[" ggufFile" ][" rfilename" ];
630
+ if (std::filesystem::exists (hff)) {
631
+ return 0 ;
632
+ }
633
+ } else {
634
+ hfr = model.substr (0 , pos);
635
+ hff = model.substr (pos + 1 );
589
636
}
590
637
591
- const std::string hfr = model.substr (0 , pos);
592
- const std::string hff = model.substr (pos + 1 );
593
- const std::string url = " https://huggingface.co/" + hfr + " /resolve/main/" + hff;
594
- return download (url, headers, bn, true );
638
+ url = " https://huggingface.co/" + hfr + " /resolve/main/" + hff;
639
+ return download (url, bn, true , headers);
595
640
}
596
641
597
- int ollama_dl (std::string & model, const std::vector<std::string> headers, const std::string & bn) {
642
+ int ollama_dl (std::string & model, const std::string & bn) {
643
+ const std::vector<std::string> headers = { " Accept: application/vnd.docker.distribution.manifest.v2+json" };
598
644
if (model.find (' /' ) == std::string::npos) {
599
645
model = " library/" + model;
600
646
}
601
647
602
- std::string model_tag = " latest" ;
603
- size_t colon_pos = model.find (' :' );
604
- if (colon_pos != std::string::npos) {
605
- model_tag = model.substr (colon_pos + 1 );
606
- model = model.substr (0 , colon_pos);
607
- }
608
-
609
- std::string manifest_url = " https://registry.ollama.ai/v2/" + model + " /manifests/" + model_tag;
610
- std::string manifest_str;
611
- const int ret = download (manifest_url, headers, " " , false , &manifest_str);
648
+ auto [model_name, manifest_url] = extract_model_and_tag (model, " https://registry.ollama.ai/v2/" );
649
+ nlohmann::json manifest;
650
+ int ret = download_and_parse_manifest (manifest_url, {}, manifest);
612
651
if (ret) {
613
652
return ret;
614
653
}
615
654
616
- nlohmann::json manifest = nlohmann::json::parse (manifest_str);
617
- std::string layer;
655
+ std::string layer;
618
656
for (const auto & l : manifest[" layers" ]) {
619
657
if (l[" mediaType" ] == " application/vnd.ollama.image.model" ) {
620
658
layer = l[" digest" ];
621
659
break ;
622
660
}
623
661
}
624
662
625
- std::string blob_url = " https://registry.ollama.ai/v2/" + model + " /blobs/" + layer;
626
- return download (blob_url, headers, bn, true );
663
+ std::string blob_url = " https://registry.ollama.ai/v2/" + model_name + " /blobs/" + layer;
664
+ return download (blob_url, bn, true , headers );
627
665
}
628
666
629
667
std::string basename (const std::string & path) {
@@ -653,22 +691,18 @@ class LlamaData {
653
691
return ret;
654
692
}
655
693
656
- const std::string bn = basename (model_);
657
- const std::vector<std::string> headers = { " --header" ,
658
- " Accept: application/vnd.docker.distribution.manifest.v2+json" };
694
+ const std::string bn = basename (model_);
659
695
if (string_starts_with (model_, " hf://" ) || string_starts_with (model_, " huggingface://" )) {
660
696
rm_until_substring (model_, " ://" );
661
- ret = huggingface_dl (model_, headers, bn);
697
+ ret = huggingface_dl (model_, bn);
662
698
} else if (string_starts_with (model_, " hf.co/" )) {
663
699
rm_until_substring (model_, " hf.co/" );
664
- ret = huggingface_dl (model_, headers, bn);
665
- } else if (string_starts_with (model_, " ollama://" )) {
666
- rm_until_substring (model_, " ://" );
667
- ret = ollama_dl (model_, headers, bn);
700
+ ret = huggingface_dl (model_, bn);
668
701
} else if (string_starts_with (model_, " https://" )) {
669
- ret = download (model_, headers, bn, true );
670
- } else {
671
- ret = ollama_dl (model_, headers, bn);
702
+ ret = download (model_, bn, true );
703
+ } else { // ollama:// or nothing
704
+ rm_until_substring (model_, " ://" );
705
+ ret = ollama_dl (model_, bn);
672
706
}
673
707
674
708
model_ = bn;
0 commit comments