Skip to content

Commit 27e683a

Browse files
ngxsonarthw
authored and committed
server : add lora hotswap endpoint (WIP) (ggml-org#8857)
* server : add lora hotswap endpoint * handle lora_no_apply * fix build * updae docs * clean up struct def * fix build * add LoRA test * fix style
1 parent fdec977 commit 27e683a

File tree

9 files changed

+251
-92
lines changed

9 files changed

+251
-92
lines changed

common/common.cpp

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -684,14 +684,24 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
684684
}
685685
if (arg == "--lora") {
686686
CHECK_ARG
687-
params.lora_adapter.emplace_back(argv[i], 1.0f);
687+
params.lora_adapters.push_back({
688+
std::string(argv[i]),
689+
1.0,
690+
});
688691
return true;
689692
}
690693
if (arg == "--lora-scaled") {
691694
CHECK_ARG
692-
const char* lora_adapter = argv[i];
695+
std::string lora_adapter = argv[i];
693696
CHECK_ARG
694-
params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
697+
params.lora_adapters.push_back({
698+
lora_adapter,
699+
std::stof(argv[i]),
700+
});
701+
return true;
702+
}
703+
if (arg == "--lora-init-without-apply") {
704+
params.lora_init_without_apply = true;
695705
return true;
696706
}
697707
if (arg == "--control-vector") {
@@ -1654,6 +1664,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
16541664
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
16551665
options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY",
16561666
"how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity });
1667+
options.push_back({ "server", " --lora-init-without-apply", "load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"});
16571668

16581669
#ifndef LOG_DISABLE_LOGS
16591670
options.push_back({ "logging" });
@@ -2091,17 +2102,22 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
20912102
}
20922103
}
20932104

2094-
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
2095-
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
2096-
float lora_scale = std::get<1>(params.lora_adapter[i]);
2097-
auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
2098-
if (adapter == nullptr) {
2099-
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
2105+
// load and optionally apply lora adapters
2106+
for (auto & la : params.lora_adapters) {
2107+
llama_lora_adapter_container loaded_la;
2108+
loaded_la.path = la.path;
2109+
loaded_la.scale = la.scale;
2110+
loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
2111+
if (loaded_la.adapter == nullptr) {
2112+
fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
21002113
llama_free(lctx);
21012114
llama_free_model(model);
21022115
return iparams;
21032116
}
2104-
llama_lora_adapter_set(lctx, adapter, lora_scale);
2117+
iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
2118+
}
2119+
if (!params.lora_init_without_apply) {
2120+
llama_lora_adapters_apply(lctx, iparams.lora_adapters);
21052121
}
21062122

21072123
if (params.ignore_eos) {
@@ -2140,6 +2156,15 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
21402156
return iparams;
21412157
}
21422158

2159+
void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters) {
2160+
llama_lora_adapter_clear(ctx);
2161+
for (auto & la : lora_adapters) {
2162+
if (la.scale != 0.0f) {
2163+
llama_lora_adapter_set(ctx, la.adapter, la.scale);
2164+
}
2165+
}
2166+
}
2167+
21432168
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
21442169
auto mparams = llama_model_default_params();
21452170

@@ -3162,19 +3187,18 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
31623187
}
31633188

31643189
fprintf(stream, "lora:\n");
3165-
for (std::tuple<std::string, float> la : params.lora_adapter) {
3166-
if (std::get<1>(la) != 1.0f) {
3167-
continue;
3190+
for (auto & la : params.lora_adapters) {
3191+
if (la.scale == 1.0f) {
3192+
fprintf(stream, " - %s\n", la.path.c_str());
31683193
}
3169-
fprintf(stream, " - %s\n", std::get<0>(la).c_str());
31703194
}
31713195
fprintf(stream, "lora_scaled:\n");
3172-
for (std::tuple<std::string, float> la : params.lora_adapter) {
3173-
if (std::get<1>(la) == 1.0f) {
3174-
continue;
3196+
for (auto & la : params.lora_adapters) {
3197+
if (la.scale != 1.0f) {
3198+
fprintf(stream, " - %s: %f\n", la.path.c_str(), la.scale);
31753199
}
3176-
fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
31773200
}
3201+
fprintf(stream, "lora_init_without_apply: %s # default: false\n", params.lora_init_without_apply ? "true" : "false");
31783202
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
31793203
fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
31803204
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);

common/common.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@
3333

3434
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
3535

36+
struct llama_lora_adapter_info {
37+
std::string path;
38+
float scale;
39+
};
40+
41+
struct llama_lora_adapter_container : llama_lora_adapter_info {
42+
struct llama_lora_adapter * adapter;
43+
};
44+
3645
// build info
3746
extern int LLAMA_BUILD_NUMBER;
3847
extern char const * LLAMA_COMMIT;
@@ -126,8 +135,8 @@ struct gpt_params {
126135
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
127136
std::vector<llama_model_kv_override> kv_overrides;
128137

129-
// TODO: avoid tuple, use struct
130-
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
138+
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
139+
std::vector<llama_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
131140

132141
std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale
133142

@@ -309,8 +318,9 @@ std::string fs_get_cache_file(const std::string & filename);
309318
//
310319

311320
struct llama_init_result {
312-
struct llama_model * model = nullptr;
321+
struct llama_model * model = nullptr;
313322
struct llama_context * context = nullptr;
323+
std::vector<llama_lora_adapter_container> lora_adapters;
314324
};
315325

316326
struct llama_init_result llama_init_from_gpt_params(gpt_params & params);
@@ -321,6 +331,9 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
321331
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
322332
struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
323333

334+
// clear LoRA adapters from context, then apply new list of adapters
335+
void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters);
336+
324337
// Batch utils
325338

326339
void llama_batch_clear(struct llama_batch & batch);

examples/export-lora/export-lora.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ struct lora_merge_ctx {
135135

136136
lora_merge_ctx(
137137
std::string & base_fname,
138-
std::vector<std::tuple<std::string, float>> & lora_files,
138+
std::vector<llama_lora_adapter_info> & lora_files,
139139
std::string & outfile,
140140
int n_threads) : base_model(base_fname, 0), n_threads(n_threads), fout(outfile, std::ios::binary) {
141141
fout.exceptions(std::ofstream::failbit); // fail fast on write errors
@@ -144,9 +144,9 @@ struct lora_merge_ctx {
144144
throw std::runtime_error("split model is not yet supported");
145145
}
146146

147-
for (auto lora_inp : lora_files) {
148-
auto fname = std::get<0>(lora_inp);
149-
auto scale = std::get<1>(lora_inp);
147+
for (auto & lora_inp : lora_files) {
148+
auto fname = lora_inp.path;
149+
auto scale = lora_inp.scale;
150150
std::unique_ptr<file_input> adapter(new file_input(fname, scale));
151151
check_metadata_lora(adapter.get());
152152
adapters.push_back(std::move(adapter));
@@ -407,7 +407,7 @@ int main(int argc, char ** argv) {
407407

408408
g_verbose = (params.verbosity == 1);
409409
try {
410-
lora_merge_ctx ctx(params.model, params.lora_adapter, params.lora_outfile, params.n_threads);
410+
lora_merge_ctx ctx(params.model, params.lora_adapters, params.lora_outfile, params.n_threads);
411411
ctx.run_merge();
412412
} catch (const std::exception & err) {
413413
fprintf(stderr, "%s\n", err.what());

0 commit comments

Comments (0)