@@ -634,20 +634,20 @@ class LlamaData {
634
634
return path.substr (pos + 1 );
635
635
}
636
636
637
- int remove_proto (std::string & model_) {
638
- const std::string::size_type pos = model_.find (" :// " );
637
+ int rm_until_substring (std::string & model_, const std::string & substring ) {
638
+ const std::string::size_type pos = model_.find (substring );
639
639
if (pos == std::string::npos) {
640
640
return 1 ;
641
641
}
642
642
643
- model_ = model_.substr (pos + 3 ) ; // Skip past "://"
643
+ model_ = model_.substr (pos + substring. size ()) ; // Skip past the substring
644
644
return 0 ;
645
645
}
646
646
647
647
int resolve_model (std::string & model_) {
648
648
int ret = 0 ;
649
649
if (string_starts_with (model_, " file://" ) || std::filesystem::exists (model_)) {
650
- remove_proto (model_);
650
+ rm_until_substring (model_, " :// " );
651
651
652
652
return ret;
653
653
}
@@ -656,13 +656,16 @@ class LlamaData {
656
656
const std::vector<std::string> headers = { " --header" ,
657
657
" Accept: application/vnd.docker.distribution.manifest.v2+json" };
658
658
if (string_starts_with (model_, " hf://" ) || string_starts_with (model_, " huggingface://" )) {
659
- remove_proto (model_);
659
+ rm_until_substring (model_, " ://" );
660
+ ret = huggingface_dl (model_, headers, bn);
661
+ } else if (string_starts_with (model_, " hf.co/" )) {
662
+ rm_until_substring (model_, " hf.co/" );
660
663
ret = huggingface_dl (model_, headers, bn);
661
664
} else if (string_starts_with (model_, " ollama://" )) {
662
- remove_proto (model_);
665
+ rm_until_substring (model_, " :// " );
663
666
ret = ollama_dl (model_, headers, bn);
664
667
} else if (string_starts_with (model_, " https://" )) {
665
- download (model_, headers, bn, true );
668
+ ret = download (model_, headers, bn, true );
666
669
} else {
667
670
ret = ollama_dl (model_, headers, bn);
668
671
}
0 commit comments