
Commit 9cb8537

hodlen authored and ggerganov committed
common : add HF arg helpers (ggml-org#6234)
* common : add HF arg helpers
* common : remove defaults
1 parent 77c74db commit 9cb8537

File tree

2 files changed: +78 -17 lines changed


common/common.cpp

Lines changed: 72 additions & 13 deletions
@@ -647,6 +647,22 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
         params.model = argv[i];
         return true;
     }
+    if (arg == "-md" || arg == "--model-draft") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.model_draft = argv[i];
+        return true;
+    }
+    if (arg == "-a" || arg == "--alias") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.model_alias = argv[i];
+        return true;
+    }
     if (arg == "-mu" || arg == "--model-url") {
         if (++i >= argc) {
             invalid_param = true;
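Each handler above follows the same one-value-per-flag pattern: advance the index, flag an error when the value is missing, otherwise store argv[i]. A minimal sketch of that pattern as a standalone helper (the helper name is hypothetical; the real gpt_params_find_arg inlines this logic for each flag):

// Sketch only: a hypothetical helper capturing the shared parse pattern.
#include <string>

static bool parse_string_value(int argc, char ** argv, int & i, bool & invalid_param, std::string & out) {
    if (++i >= argc) {        // the flag was the last token; its value is missing
        invalid_param = true; // caller reports usage and aborts parsing
        return true;          // the flag itself was still recognized
    }
    out = argv[i];            // consume the value token
    return true;
}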
@@ -655,20 +671,20 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
         params.model_url = argv[i];
         return true;
     }
-    if (arg == "-md" || arg == "--model-draft") {
+    if (arg == "-hfr" || arg == "--hf-repo") {
         if (++i >= argc) {
             invalid_param = true;
             return true;
         }
-        params.model_draft = argv[i];
+        params.hf_repo = argv[i];
         return true;
     }
-    if (arg == "-a" || arg == "--alias") {
+    if (arg == "-hff" || arg == "--hf-file") {
         if (++i >= argc) {
             invalid_param = true;
             return true;
         }
-        params.model_alias = argv[i];
+        params.hf_file = argv[i];
         return true;
     }
     if (arg == "--lora") {
@@ -1403,10 +1419,14 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        layer range to apply the control vector(s) to, start and end inclusive\n");
     printf("  -m FNAME, --model FNAME\n");
     printf("                        model path (default: %s)\n", params.model.c_str());
-    printf("  -mu MODEL_URL, --model-url MODEL_URL\n");
-    printf("                        model download url (default: %s)\n", params.model_url.c_str());
     printf("  -md FNAME, --model-draft FNAME\n");
-    printf("                        draft model for speculative decoding\n");
+    printf("                        draft model for speculative decoding (default: unused)\n");
+    printf("  -mu MODEL_URL, --model-url MODEL_URL\n");
+    printf("                        model download url (default: unused)\n");
+    printf("  -hfr REPO, --hf-repo REPO\n");
+    printf("                        Hugging Face model repository (default: unused)\n");
+    printf("  -hff FILE, --hf-file FILE\n");
+    printf("                        Hugging Face model file (default: unused)\n");
     printf("  -ld LOGDIR, --logdir LOGDIR\n");
     printf("                        path under which to save YAML logs (no logging if unset)\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
@@ -1655,8 +1675,10 @@ void llama_batch_add(
 
 #ifdef LLAMA_USE_CURL
 
-struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
-                                               struct llama_model_params params) {
+struct llama_model * llama_load_model_from_url(
+        const char * model_url,
+        const char * path_model,
+        const struct llama_model_params & params) {
     // Basic validation of the model_url
     if (!model_url || strlen(model_url) == 0) {
         fprintf(stderr, "%s: invalid model_url\n", __func__);
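Note the signature change here: llama_load_model_from_url now takes llama_model_params by const reference instead of by value, so the parameter struct is no longer copied on every call. The new llama_load_model_from_hf introduced below follows the same convention.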
@@ -1850,25 +1872,62 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
     return llama_load_model_from_file(path_model, params);
 }
 
+struct llama_model * llama_load_model_from_hf(
+        const char * repo,
+        const char * model,
+        const char * path_model,
+        const struct llama_model_params & params) {
+    // construct hugging face model url:
+    //
+    //  --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
+    //    https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
+    //
+    //  --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
+    //    https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
+    //
+
+    std::string model_url = "https://huggingface.co/";
+    model_url += repo;
+    model_url += "/resolve/main/";
+    model_url += model;
+
+    return llama_load_model_from_url(model_url.c_str(), path_model, params);
+}
+
 #else
 
-struct llama_model * llama_load_model_from_url(const char * /*model_url*/, const char * /*path_model*/,
-                                               struct llama_model_params /*params*/) {
+struct llama_model * llama_load_model_from_url(
+        const char * /*model_url*/,
+        const char * /*path_model*/,
+        const struct llama_model_params & /*params*/) {
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
     return nullptr;
 }
 
+struct llama_model * llama_load_model_from_hf(
+        const char * /*repo*/,
+        const char * /*model*/,
+        const char * /*path_model*/,
+        const struct llama_model_params & /*params*/) {
+    fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
+    return nullptr;
+}
+
 #endif // LLAMA_USE_CURL
 
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
     auto mparams = llama_model_params_from_gpt_params(params);
 
     llama_model * model = nullptr;
-    if (!params.model_url.empty()) {
+
+    if (!params.hf_repo.empty() && !params.hf_file.empty()) {
+        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
+    } else if (!params.model_url.empty()) {
         model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
     } else {
        model = llama_load_model_from_file(params.model.c_str(), mparams);
     }
+
     if (model == NULL) {
         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
         return std::make_tuple(nullptr, nullptr);
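llama_load_model_from_hf builds the download URL by plain concatenation against Hugging Face's fixed https://huggingface.co/<repo>/resolve/main/<file> layout, then delegates to llama_load_model_from_url. A standalone sketch reproducing the construction with the second example from the comment above:

#include <cstdio>
#include <string>

int main() {
    const std::string repo = "TheBloke/Mixtral-8x7B-v0.1-GGUF";
    const std::string file = "mixtral-8x7b-v0.1.Q4_K_M.gguf";

    std::string model_url = "https://huggingface.co/";
    model_url += repo;
    model_url += "/resolve/main/";
    model_url += file;

    // prints: https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
    printf("%s\n", model_url.c_str());
    return 0;
}

In llama_init_from_gpt_params, the HF pair takes precedence: hf_repo/hf_file are tried first, then --model-url, and finally the local --model path.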
@@ -1908,7 +1967,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     }
 
     for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
-        const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]);
+        const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
         float lora_scale = std::get<1>(params.lora_adapter[i]);
         int err = llama_model_apply_lora_from_file(model,
                                                    lora_adapter.c_str(),

common/common.h

Lines changed: 6 additions & 4 deletions
@@ -89,9 +89,11 @@ struct gpt_params {
     struct llama_sampling_params sparams;
 
     std::string model       = "models/7B/ggml-model-f16.gguf"; // model path
-    std::string model_url   = "";        // model url to download
-    std::string model_draft = "";        // draft model for speculative decoding
+    std::string model_draft = "";        // draft model for speculative decoding
     std::string model_alias = "unknown"; // model alias
+    std::string model_url   = "";        // model url to download
+    std::string hf_repo     = "";        // HF repo
+    std::string hf_file     = "";        // HF file
     std::string prompt      = "";
     std::string prompt_file = "";        // store the external prompt file name
     std::string path_prompt_cache = "";  // path to file for saving/loading prompt eval state
@@ -192,8 +194,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
-struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
-                                               struct llama_model_params params);
+struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
+struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
 
 // Batch utils
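A hedged end-to-end caller for the new helper, assuming a build with LLAMA_USE_CURL (the stub otherwise prints an error and returns nullptr); the repo/file values reuse the examples from common.cpp and the local target path is illustrative:

#include "common.h"
#include "llama.h"

int main() {
    llama_backend_init(); // initialize ggml backends, as the llama.cpp examples do

    struct llama_model_params mparams = llama_model_default_params();

    struct llama_model * model = llama_load_model_from_hf(
        "ggml-org/models",                    // HF repo
        "tinyllama-1.1b/ggml-model-f16.gguf", // file within the repo
        "ggml-model-f16.gguf",                // local path to download to / load from
        mparams);

    if (model == nullptr) {
        llama_backend_free();
        return 1; // download failed, or llama.cpp was built without libcurl
    }

    llama_free_model(model);
    llama_backend_free();
    return 0;
}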
