server : add lora hotswap endpoint #8857

Merged · 9 commits · Aug 6, 2024

60 changes: 42 additions & 18 deletions common/common.cpp
@@ -684,14 +684,24 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     }
     if (arg == "--lora") {
         CHECK_ARG
-        params.lora_adapter.emplace_back(argv[i], 1.0f);
+        params.lora_adapters.push_back({
+            std::string(argv[i]),
+            1.0,
+        });
         return true;
     }
     if (arg == "--lora-scaled") {
         CHECK_ARG
-        const char* lora_adapter = argv[i];
+        std::string lora_adapter = argv[i];
         CHECK_ARG
-        params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
+        params.lora_adapters.push_back({
+            lora_adapter,
+            std::stof(argv[i]),
+        });
         return true;
     }
+    if (arg == "--lora-init-without-apply") {
+        params.lora_init_without_apply = true;
+        return true;
+    }
     if (arg == "--control-vector") {
@@ -1654,6 +1664,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
         "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
     options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY",
         "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity });
+    options.push_back({ "server", " --lora-init-without-apply", "load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"});
 
 #ifndef LOG_DISABLE_LOGS
     options.push_back({ "logging" });
@@ -2091,17 +2102,22 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         }
     }
 
-    for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
-        const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
-        float lora_scale = std::get<1>(params.lora_adapter[i]);
-        auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
-        if (adapter == nullptr) {
-            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+    // load and optionally apply lora adapters
+    for (auto & la : params.lora_adapters) {
+        llama_lora_adapter_container loaded_la;
+        loaded_la.path = la.path;
+        loaded_la.scale = la.scale;
+        loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
+        if (loaded_la.adapter == nullptr) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
             llama_free(lctx);
             llama_free_model(model);
             return iparams;
         }
-        llama_lora_adapter_set(lctx, adapter, lora_scale);
+        iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
     }
+    if (!params.lora_init_without_apply) {
+        llama_lora_adapters_apply(lctx, iparams.lora_adapters);
+    }
 
     if (params.ignore_eos) {
@@ -2140,6 +2156,15 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     return iparams;
 }
 
+void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters) {
+    llama_lora_adapter_clear(ctx);
+    for (auto & la : lora_adapters) {
+        if (la.scale != 0.0f) {
+            llama_lora_adapter_set(ctx, la.adapter, la.scale);
+        }
+    }
+}
+
 struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
     auto mparams = llama_model_default_params();
 
@@ -3162,19 +3187,18 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     }
 
     fprintf(stream, "lora:\n");
-    for (std::tuple<std::string, float> la : params.lora_adapter) {
-        if (std::get<1>(la) != 1.0f) {
-            continue;
+    for (auto & la : params.lora_adapters) {
+        if (la.scale == 1.0f) {
+            fprintf(stream, " - %s\n", la.path.c_str());
         }
-        fprintf(stream, " - %s\n", std::get<0>(la).c_str());
     }
     fprintf(stream, "lora_scaled:\n");
-    for (std::tuple<std::string, float> la : params.lora_adapter) {
-        if (std::get<1>(la) == 1.0f) {
-            continue;
+    for (auto & la : params.lora_adapters) {
+        if (la.scale != 1.0f) {
+            fprintf(stream, " - %s: %f\n", la.path.c_str(), la.scale);
         }
-        fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
     }
+    fprintf(stream, "lora_init_without_apply: %s # default: false\n", params.lora_init_without_apply ? "true" : "false");
     fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
     fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
     fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
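
The new llama_lora_adapters_apply helper is the piece that enables hot-swapping: it clears every adapter from the context and re-applies only the entries whose scale is non-zero. Below is a minimal sketch of how a caller (for example the server's POST /lora-adapters handler) could rescale an adapter at runtime after llama_init_from_gpt_params has loaded it; the set_adapter_scale helper and the adapter path are illustrative, not part of this PR:

    #include "common.h"

    // Rescale one loaded adapter, then re-apply the whole list to the context.
    // Assumes iparams came from llama_init_from_gpt_params(), so each entry in
    // iparams.lora_adapters already holds a valid llama_lora_adapter pointer.
    static void set_adapter_scale(llama_init_result & iparams, const std::string & path, float scale) {
        for (auto & la : iparams.lora_adapters) {
            if (la.path == path) {
                la.scale = scale; // 0.0f keeps the adapter loaded but inactive
            }
        }
        // clears all adapters from the context, then sets every adapter with a non-zero scale
        llama_lora_adapters_apply(iparams.context, iparams.lora_adapters);
    }

    // usage (hypothetical path): set_adapter_scale(iparams, "loras/my-adapter.gguf", 0.5f);
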
19 changes: 16 additions & 3 deletions common/common.h
@@ -33,6 +33,15 @@
 
 #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
 
+struct llama_lora_adapter_info {
+    std::string path;
+    float scale;
+};
+
+struct llama_lora_adapter_container : llama_lora_adapter_info {
+    struct llama_lora_adapter * adapter;
+};
+
 // build info
 extern int LLAMA_BUILD_NUMBER;
 extern char const * LLAMA_COMMIT;
@@ -126,8 +135,8 @@ struct gpt_params {
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
     std::vector<llama_model_kv_override> kv_overrides;
 
-    // TODO: avoid tuple, use struct
-    std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
+    bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
+    std::vector<llama_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
 
     std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale
 
@@ -309,8 +318,9 @@ std::string fs_get_cache_file(const std::string & filename);
 //
 
 struct llama_init_result {
-    struct llama_model * model = nullptr;
+    struct llama_model * model = nullptr;
     struct llama_context * context = nullptr;
+    std::vector<llama_lora_adapter_container> lora_adapters;
 };
 
 struct llama_init_result llama_init_from_gpt_params(gpt_params & params);
@@ -321,6 +331,9 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
 struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
 struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
 
+// clear LoRA adapters from context, then apply new list of adapters
+void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters);
+
 // Batch utils
 
 void llama_batch_clear(struct llama_batch & batch);
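
Because the tuple-based params.lora_adapter field is replaced by plain structs, adapters can also be configured programmatically rather than via CLI flags. A rough sketch under the new API (the adapter file names are made up for illustration):

    gpt_params params;
    params.model = "models/7B/ggml-model-f16.gguf";
    // same semantics as --lora (scale 1.0) and --lora-scaled
    params.lora_adapters.push_back({ "loras/adapter-a.gguf", 1.0f });
    params.lora_adapters.push_back({ "loras/adapter-b.gguf", 0.5f });
    // load the adapters into memory only; apply them later with llama_lora_adapters_apply
    params.lora_init_without_apply = true;

    llama_init_result iparams = llama_init_from_gpt_params(params);
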
10 changes: 5 additions & 5 deletions examples/export-lora/export-lora.cpp
@@ -135,7 +135,7 @@ struct lora_merge_ctx {
 
     lora_merge_ctx(
             std::string & base_fname,
-            std::vector<std::tuple<std::string, float>> & lora_files,
+            std::vector<llama_lora_adapter_info> & lora_files,
             std::string & outfile,
             int n_threads) : base_model(base_fname, 0), n_threads(n_threads), fout(outfile, std::ios::binary) {
         fout.exceptions(std::ofstream::failbit); // fail fast on write errors
@@ -144,9 +144,9 @@ struct lora_merge_ctx {
             throw std::runtime_error("split model is not yet supported");
         }
 
-        for (auto lora_inp : lora_files) {
-            auto fname = std::get<0>(lora_inp);
-            auto scale = std::get<1>(lora_inp);
+        for (auto & lora_inp : lora_files) {
+            auto fname = lora_inp.path;
+            auto scale = lora_inp.scale;
             std::unique_ptr<file_input> adapter(new file_input(fname, scale));
             check_metadata_lora(adapter.get());
             adapters.push_back(std::move(adapter));
@@ -407,7 +407,7 @@ int main(int argc, char ** argv) {
 
     g_verbose = (params.verbosity == 1);
     try {
-        lora_merge_ctx ctx(params.model, params.lora_adapter, params.lora_outfile, params.n_threads);
+        lora_merge_ctx ctx(params.model, params.lora_adapters, params.lora_outfile, params.n_threads);
         ctx.run_merge();
     } catch (const std::exception & err) {
         fprintf(stderr, "%s\n", err.what());