Skip to content

Commit 3dd397b

Browse files
ikawrakowKawrakow
authored andcommitted
Adding some imatrix tools (ggml-org#5302)
* imatrix: adding --combine and --continue-from * imatrix: be able to start from a specific chunk --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 1b55cd9 commit 3dd397b

File tree

1 file changed

+112
-4
lines changed

1 file changed

+112
-4
lines changed

examples/imatrix/imatrix.cpp

Lines changed: 112 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class IMatrixCollector {
3636
void set_parameters(StatParams&& params) { m_params = std::move(params); }
3737
bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data);
3838
void save_imatrix() const;
39+
bool load_imatrix(const char * file_name, bool add);
40+
static bool load_imatrix(const char * file_name, std::unordered_map<std::string, Stats>& imatrix);
3941
private:
4042
std::unordered_map<std::string, Stats> m_stats;
4143
StatParams m_params;
@@ -189,6 +191,57 @@ void IMatrixCollector::save_imatrix(const char * fname) const {
189191
}
190192
}
191193

194+
bool IMatrixCollector::load_imatrix(const char * imatrix_file, std::unordered_map<std::string, Stats>& imatrix_data) {
195+
std::ifstream in(imatrix_file, std::ios::binary);
196+
if (!in) {
197+
printf("%s: failed to open %s\n",__func__,imatrix_file);
198+
return false;
199+
}
200+
int n_entries;
201+
in.read((char*)&n_entries, sizeof(n_entries));
202+
if (in.fail() || n_entries < 1) {
203+
printf("%s: no data in file %s\n", __func__, imatrix_file);
204+
return false;
205+
}
206+
for (int i = 0; i < n_entries; ++i) {
207+
int len; in.read((char *)&len, sizeof(len));
208+
std::vector<char> name_as_vec(len+1);
209+
in.read((char *)name_as_vec.data(), len);
210+
if (in.fail()) {
211+
printf("%s: failed reading name for entry %d from %s\n",__func__,i+1,imatrix_file);
212+
return false;
213+
}
214+
name_as_vec[len] = 0;
215+
std::string name{name_as_vec.data()};
216+
auto& e = imatrix_data[std::move(name)];
217+
int ncall;
218+
in.read((char*)&ncall, sizeof(ncall));
219+
int nval;
220+
in.read((char *)&nval, sizeof(nval));
221+
if (in.fail() || nval < 1) {
222+
printf("%s: failed reading number of values for entry %d\n",__func__,i);
223+
imatrix_data = {};
224+
return false;
225+
}
226+
e.values.resize(nval);
227+
in.read((char*)e.values.data(), nval*sizeof(float));
228+
if (in.fail()) {
229+
printf("%s: failed reading data for entry %d\n",__func__,i);
230+
imatrix_data = {};
231+
return false;
232+
}
233+
e.ncall = ncall;
234+
}
235+
return true;
236+
}
237+
238+
bool IMatrixCollector::load_imatrix(const char * file_name, bool add) {
239+
if (!add) {
240+
m_stats.clear();
241+
}
242+
return load_imatrix(file_name, m_stats);
243+
}
244+
192245
static IMatrixCollector g_collector;
193246

194247
static bool ik_collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) {
@@ -269,7 +322,7 @@ static void process_logits(
269322
}
270323
}
271324

272-
static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl) {
325+
static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) {
273326

274327
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
275328
const int n_ctx = llama_n_ctx(ctx);
@@ -282,6 +335,15 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
282335
auto tim2 = std::chrono::high_resolution_clock::now();
283336
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
284337

338+
if (from_chunk > 0) {
339+
if (size_t((from_chunk + 2)*n_ctx) >= tokens.size()) {
340+
fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, from_chunk);
341+
return false;
342+
}
343+
fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, from_chunk, from_chunk*n_ctx);
344+
tokens.erase(tokens.begin(), tokens.begin() + from_chunk*n_ctx);
345+
}
346+
285347
if (int(tokens.size()) < 2*n_ctx) {
286348
fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx,
287349
n_ctx);
@@ -402,7 +464,10 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
402464
int main(int argc, char ** argv) {
403465

404466
StatParams sparams;
467+
std::string prev_result_file;
468+
std::string combine_files;
405469
bool compute_ppl = true;
470+
int from_chunk = 0;
406471
std::vector<char*> args;
407472
args.push_back(argv[0]);
408473
int iarg = 1;
@@ -423,6 +488,13 @@ int main(int argc, char ** argv) {
423488
compute_ppl = false;
424489
} else if (arg == "--keep-imatrix") {
425490
sparams.keep_every = std::stoi(argv[++iarg]);
491+
} else if (arg == "--continue-from") {
492+
prev_result_file = argv[++iarg];
493+
} else if (arg == "--combine") {
494+
combine_files = argv[++iarg];
495+
}
496+
else if (arg == "--from-chunk") {
497+
from_chunk = std::stoi(argv[++iarg]);
426498
} else {
427499
args.push_back(argv[iarg]);
428500
}
@@ -436,14 +508,50 @@ int main(int argc, char ** argv) {
436508
}
437509
}
438510

511+
g_collector.set_parameters(std::move(sparams));
512+
513+
if (!combine_files.empty()) {
514+
std::vector<std::string> files;
515+
size_t pos = 0;
516+
while (true) {
517+
auto new_pos = combine_files.find(',', pos);
518+
if (new_pos != std::string::npos) {
519+
files.emplace_back(combine_files.substr(pos, new_pos - pos));
520+
pos = new_pos + 1;
521+
} else {
522+
files.emplace_back(combine_files.substr(pos));
523+
break;
524+
}
525+
}
526+
if (files.size() < 2) {
527+
fprintf(stderr, "You must provide at least two comma separated files to use --combine\n");
528+
return 1;
529+
}
530+
printf("Combining the following %d files\n", int(files.size()));
531+
for (auto& file : files) {
532+
printf(" %s\n", file.c_str());
533+
if (!g_collector.load_imatrix(file.c_str(), true)) {
534+
fprintf(stderr, "Failed to load %s\n", file.c_str());
535+
return 1;
536+
}
537+
}
538+
g_collector.save_imatrix();
539+
return 0;
540+
}
541+
542+
if (!prev_result_file.empty()) {
543+
if (!g_collector.load_imatrix(prev_result_file.c_str(), false)) {
544+
fprintf(stderr, "=============== Failed to load %s\n", prev_result_file.c_str());
545+
return 1;
546+
}
547+
}
548+
439549
gpt_params params;
440550
params.n_batch = 512;
441551
if (!gpt_params_parse(args.size(), args.data(), params)) {
442552
return 1;
443553
}
444554

445-
g_collector.set_parameters(std::move(sparams));
446-
447555
params.logits_all = true;
448556
params.n_batch = std::min(params.n_batch, params.n_ctx);
449557

@@ -495,7 +603,7 @@ int main(int argc, char ** argv) {
495603
fprintf(stderr, "%s\n", get_system_info(params).c_str());
496604
}
497605

498-
bool OK = compute_imatrix(ctx, params, compute_ppl);
606+
bool OK = compute_imatrix(ctx, params, compute_ppl, from_chunk);
499607
if (!OK) {
500608
return 1;
501609
}

0 commit comments

Comments
 (0)