
Commit 272cd0e

common : update lora
ggml-ci
1 parent 8d117a5 commit 272cd0e

8 files changed: +40 additions, −40 deletions

common/arg.cpp

Lines changed: 2 additions & 2 deletions

@@ -1512,15 +1512,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--lora"}, "FNAME",
         "path to LoRA adapter (can be repeated to use multiple adapters)",
         [](common_params & params, const std::string & value) {
-            params.lora_adapters.push_back({ std::string(value), 1.0 });
+            params.lora_adapters.push_back({ std::string(value), 1.0, nullptr });
         }
         // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
     add_opt(common_arg(
         {"--lora-scaled"}, "FNAME", "SCALE",
         "path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
         [](common_params & params, const std::string & fname, const std::string & scale) {
-            params.lora_adapters.push_back({ fname, std::stof(scale) });
+            params.lora_adapters.push_back({ fname, std::stof(scale), nullptr });
         }
         // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));

common/common.cpp

Lines changed: 11 additions & 10 deletions

@@ -922,20 +922,21 @@ struct common_init_result common_init_from_params(common_params & params) {

     // load and optionally apply lora adapters
     for (auto & la : params.lora_adapters) {
-        common_lora_adapter_container loaded_la;
-        loaded_la.path = la.path;
-        loaded_la.scale = la.scale;
-        loaded_la.adapter.reset(llama_lora_adapter_init(model, la.path.c_str()));
-        if (loaded_la.adapter == nullptr) {
+        llama_lora_adapter_ptr lora;
+        lora.reset(llama_lora_adapter_init(model, la.path.c_str()));
+        if (lora == nullptr) {
             LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
             llama_free(lctx);
             llama_free_model(model);
             return iparams;
         }
-        iparams.lora_adapters.emplace_back(std::move(loaded_la)); // copy to list of loaded adapters
+
+        la.ptr = lora.get();
+        iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
     }
+
     if (!params.lora_init_without_apply) {
-        common_lora_adapters_apply(lctx, iparams.lora_adapters);
+        common_lora_adapters_apply(lctx, params.lora_adapters);
     }

     if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {

@@ -1002,11 +1003,11 @@ struct common_init_result common_init_from_params(common_params & params) {
     return iparams;
 }

-void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters) {
+void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora) {
     llama_lora_adapter_clear(ctx);
-    for (auto & la : lora_adapters) {
+    for (auto & la : lora) {
         if (la.scale != 0.0f) {
-            llama_lora_adapter_set(ctx, la.adapter.get(), la.scale);
+            llama_lora_adapter_set(ctx, la.ptr, la.scale);
         }
     }
 }
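
The rewritten common_lora_adapters_apply clears the context and then re-applies only the adapters whose scale is non-zero, so a loaded adapter can be disabled without being freed. A minimal sketch, assuming lctx and params.lora_adapters were populated by common_init_from_params as above:

    // hypothetical toggle: keep the adapter in memory, but drop it from the context
    params.lora_adapters[0].scale = 0.0f;                    // 0.0f => skipped by the apply loop
    common_lora_adapters_apply(lctx, params.lora_adapters);  // llama_lora_adapter_clear + set the rest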

common/common.h

Lines changed: 5 additions & 5 deletions

@@ -24,13 +24,12 @@

 #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"

+// TODO: "lora_adapter" is tautology
 struct common_lora_adapter_info {
     std::string path;
     float scale;
-};

-struct common_lora_adapter_container : common_lora_adapter_info {
-    llama_lora_adapter_ptr adapter;
+    struct llama_lora_adapter * ptr;
 };

 using llama_tokens = std::vector<llama_token>;

@@ -478,11 +477,12 @@ std::string fs_get_cache_file(const std::string & filename);
 // Model utils
 //

+// note: defines object's lifetime
 struct common_init_result {
     llama_model_ptr model;
     llama_context_ptr context;

-    std::vector<common_lora_adapter_container> lora_adapters;
+    std::vector<llama_lora_adapter_ptr> lora;
 };

 struct common_init_result common_init_from_params(common_params & params);

@@ -504,7 +504,7 @@ struct llama_model * common_load_model_from_hf(
     const struct llama_model_params & params);

 // clear LoRA adapters from context, then apply new list of adapters
-void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
+void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora);

 //
 // Batch utils
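
The header change splits description from ownership: common_lora_adapter_info carries the path, scale, and a non-owning pointer, while common_init_result::lora holds the owning llama_lora_adapter_ptr values, so the adapters live exactly as long as the init result (hence the "defines object's lifetime" note). A minimal usage sketch, with a hypothetical adapter path and otherwise only names that appear in this diff:

    common_params params;
    params.lora_adapters.push_back({ "lora-adapter.gguf", 1.0f, nullptr }); // ptr is filled in during init

    common_init_result llama_init = common_init_from_params(params);
    // llama_init.lora now owns the adapters; params.lora_adapters[i].ptr points at them,
    // and they were already applied unless params.lora_init_without_apply was set.
    common_lora_adapters_apply(llama_init.context.get(), params.lora_adapters); // clear + re-apply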

examples/server/server.cpp

Lines changed: 11 additions & 16 deletions

@@ -98,7 +98,7 @@ struct slot_params {
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

-    std::vector<common_lora_adapter_container> lora;
+    std::vector<common_lora_adapter_info> lora;

     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;

@@ -198,15 +198,14 @@ struct server_task {
     bool metrics_reset_bucket = false;

     // used by SERVER_TASK_TYPE_SET_LORA
-    std::vector<common_lora_adapter_container> set_lora;
+    std::vector<common_lora_adapter_info> set_lora;

     server_task(server_task_type type) : type(type) {}

     static slot_params params_from_json_cmpl(
             const llama_model * model,
             const llama_context * ctx,
             const common_params & params_base,
-            const std::vector<common_lora_adapter_container> & lora_base,
             const json & data) {
         slot_params params;

@@ -265,12 +264,12 @@ struct server_task {

         if (data.contains("lora")) {
             if (data.at("lora").is_array()) {
-                params.lora = parse_lora_request(lora_base, data.at("lora"));
+                params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
             } else {
                 throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
             }
         } else {
-            params.lora = lora_base;
+            params.lora = params_base.lora_adapters;
         }

         // TODO: add more sanity checks for the input parameters

@@ -1132,7 +1131,7 @@ struct server_slot {

     common_speculative * spec = nullptr;

-    std::vector<common_lora_adapter_container> lora;
+    std::vector<common_lora_adapter_info> lora;

     // the index relative to completion multi-task request
     size_t index = 0;

@@ -1633,8 +1632,6 @@ struct server_context {
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;

-    std::vector<common_lora_adapter_container> lora;
-
     llama_model * model_dft = nullptr;

     llama_context_params cparams_dft;

@@ -1687,8 +1684,6 @@ struct server_context {
         model = llama_init.model.get();
         ctx = llama_init.context.get();

-        lora = std::move(llama_init.lora_adapters);
-
         if (model == nullptr) {
             SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
             return false;

@@ -1883,7 +1878,7 @@ struct server_context {
         if (!are_lora_equal(task.params.lora, slot.lora)) {
             // if lora is changed, we cannot reuse cached tokens
             slot.cache_tokens.clear();
-            slot.lora = std::move(task.params.lora);
+            slot.lora = task.params.lora;
         }

         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());

@@ -2577,7 +2572,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SET_LORA:
                 {
-                    lora = std::move(task.set_lora);
+                    params_base.lora_adapters = std::move(task.set_lora);
                     auto res = std::make_unique<server_task_result_apply_lora>();
                     res->id = task.id;
                     queue_results.send(std::move(res));

@@ -3656,7 +3651,6 @@ int main(int argc, char ** argv) {
                 ctx_server.model,
                 ctx_server.ctx,
                 ctx_server.params_base,
-                ctx_server.lora,
                 data);
             task.id_selected_slot = json_value(data, "id_slot", -1);

@@ -4083,8 +4077,9 @@ int main(int argc, char ** argv) {

     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
-        for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
-            auto & lora = ctx_server.lora[i];
+        const auto & loras = ctx_server.params_base.lora_adapters;
+        for (size_t i = 0; i < loras.size(); ++i) {
+            auto & lora = loras[i];
             result.push_back({
                 {"id", i},
                 {"path", lora.path},

@@ -4103,7 +4098,7 @@ int main(int argc, char ** argv) {
         }
         server_task task(SERVER_TASK_TYPE_SET_LORA);
         task.id = ctx_server.queue_tasks.get_new_id();
-        task.set_lora = parse_lora_request(ctx_server.lora, body);
+        task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
         ctx_server.queue_results.add_waiting_task_id(task.id);
         ctx_server.queue_tasks.post(task);
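
With lora_base removed from params_from_json_cmpl and the server_context copy of the list gone, per-request adapter selection now always starts from params_base.lora_adapters. A hedged illustration of the request field it parses, using the json type the server already uses; the ids index into that adapter list, and the concrete values are made up:

    // hypothetical "lora" fragment of a completion request: adapter 0 at half strength, adapter 1 off
    json data;
    data["lora"] = json::array({
        { {"id", 0}, {"scale", 0.5} },
        { {"id", 1}, {"scale", 0.0} },
    });
    // params_from_json_cmpl then calls parse_lora_request(params_base.lora_adapters, data.at("lora"))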

examples/server/utils.hpp

Lines changed: 7 additions & 7 deletions

@@ -799,25 +799,25 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx
 }

 static bool are_lora_equal(
-        const std::vector<common_lora_adapter_container> & l1,
-        const std::vector<common_lora_adapter_container> & l2) {
+        const std::vector<common_lora_adapter_info> & l1,
+        const std::vector<common_lora_adapter_info> & l2) {
     if (l1.size() != l2.size()) {
         return false;
     }
     for (size_t i = 0; i < l1.size(); ++i) {
         // we don't check lora.path to reduce the time complexity
-        if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) {
+        if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) {
             return false;
         }
     }
     return true;
 }

-// parse lora config from JSON request, returned a copy of base_lora with updated scale
-static std::vector<common_lora_adapter_container> parse_lora_request(
-        const std::vector<common_lora_adapter_container> & base_lora,
+// parse lora config from JSON request, returned a copy of lora_base with updated scale
+static std::vector<common_lora_adapter_info> parse_lora_request(
+        const std::vector<common_lora_adapter_info> & lora_base,
         const json & data) {
-    std::vector<common_lora_adapter_container> lora(base_lora);
+    std::vector<common_lora_adapter_info> lora(lora_base);
     int max_idx = lora.size();

     // clear existing value
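
Because common_lora_adapter_info is now a plain value type (path, scale, raw pointer), parse_lora_request can copy the base list cheaply and are_lora_equal only compares scale and pointer identity. A small sketch of that equality rule, assuming base is the adapter list populated at startup:

    // two infos referring to the same loaded adapter compare equal iff their scales match
    common_lora_adapter_info a = base[0];
    common_lora_adapter_info b = base[0];
    b.scale = 0.5f;
    // are_lora_equal({a}, {b}) == false -> the slot's cached prompt tokens are cleared on the next task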

src/llama-impl.cpp

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@
 #include <cinttypes>
 #include <climits>
 #include <cstdarg>
+#include <cstring>
 #include <vector>
 #include <sstream>
src/llama-model-loader.cpp

Lines changed: 2 additions & 0 deletions

@@ -2,7 +2,9 @@

 #include "ggml.h"

+#include <array>
 #include <cinttypes>
+#include <cstring>
 #include <future>

 const char * llama_file_version_name(llama_fver version) {

src/llama-model-loader.h

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@

 #include <cstddef>
 #include <map>
+#include <stdexcept>
 #include <unordered_map>

 using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;
