Skip to content

Commit cd156d7

Browse files
committed
Fix issues revealed by CI
1 parent a645e01 commit cd156d7

File tree

2 files changed

+25
-24
lines changed

2 files changed

+25
-24
lines changed

llama-util.h

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <vector>
1717
#include <map>
1818
#include <unordered_map>
19+
#include <memory>
1920
#include <stdexcept>
2021

2122
#ifdef __has_include
@@ -492,7 +493,7 @@ typedef llama_buffer llama_ctx_buffer;
492493
struct llama_trie_node {
493494
llama_trie_node(): is_terminator(false) {}
494495

495-
std::unordered_map<char, llama_trie_node*> children;
496+
std::unordered_map<char, std::unique_ptr<llama_trie_node>> children;
496497
bool is_terminator;
497498
};
498499

@@ -507,24 +508,24 @@ struct llama_trie {
507508
return;
508509
}
509510

510-
llama_trie_node *ref = root_;
511+
llama_trie_node *ref = root_.get();
511512
for (char c : word) {
512513
if (ref->children.find(c) == ref->children.end()) {
513-
ref->children[c] = new llama_trie_node();
514+
ref->children[c].reset(new llama_trie_node());
514515
}
515-
ref = ref->children[c];
516+
ref = ref->children[c].get();
516517
}
517518
ref->is_terminator = true;
518519
}
519520

520521
// Will look for the words added to the trie within `text`. Output is the boundaries of the words found.
521522
// Note that this trie will match the longest possible word first!
522-
std::vector<int> split(const std::string & text) const {
523-
std::map<int, llama_trie_node*> states;
524-
std::vector<int> offsets{0};
523+
std::vector<size_t> split(const std::string & text) const {
524+
std::map<size_t, llama_trie_node*> states;
525+
std::vector<size_t> offsets{0};
525526

526-
int skip = 0;
527-
for (int current = 0; current < text.size(); current++) {
527+
size_t skip = 0;
528+
for (size_t current = 0; current < text.size(); current++) {
528529
char current_char = text[current];
529530
if (skip > 0 && current < skip) {
530531
// Prevents the lookahead for matching twice
@@ -538,7 +539,7 @@ struct llama_trie {
538539

539540
// In this case, we already have partial matches (But unfinished)
540541
for (auto state = states.begin(); state != states.end(); ) {
541-
int start = state->first;
542+
size_t start = state->first;
542543
llama_trie_node *trie_pointer = state->second;
543544
if (trie_pointer->is_terminator) {
544545
// This is a final match, we need to reset and
@@ -549,11 +550,11 @@ struct llama_trie {
549550
// Here we are also actively looking for other earlier partial
550551
// matches
551552
// "[CLS]", "L", we need to match CLS even if L is special
552-
int end = 0;
553+
size_t end = 0;
553554
for (const auto & look : states) {
554-
int lookstart = look.first;
555+
size_t lookstart = look.first;
555556
llama_trie_node *looktrie_pointer = look.second;
556-
int lookahead_index = 0;
557+
size_t lookahead_index = 0;
557558
if (lookstart > start) {
558559
// This partial match is later, we can stop looking
559560
break;
@@ -579,7 +580,7 @@ struct llama_trie {
579580

580581
auto looktrie_pointer_it = looktrie_pointer->children.find(next_char);
581582
while (looktrie_pointer_it != looktrie_pointer->children.end()) {
582-
looktrie_pointer = looktrie_pointer_it->second;
583+
looktrie_pointer = looktrie_pointer_it->second.get();
583584
lookahead_index++;
584585
if (looktrie_pointer->is_terminator) {
585586
start = lookstart;
@@ -606,7 +607,7 @@ struct llama_trie {
606607
if (trie_pointer_it != trie_pointer->children.end()) {
607608
// The current character being looked at has a match within the trie
608609
// update the pointer (it will be stored back into states later).
609-
trie_pointer = trie_pointer_it->second;
610+
trie_pointer = trie_pointer_it->second.get();
610611
states[start] = trie_pointer;
611612
++state;
612613
} else {
@@ -625,18 +626,18 @@ struct llama_trie {
625626
// start keeping track of this partial match.
626627
auto children_it = root_->children.find(current_char);
627628
if (current >= skip && children_it != root_->children.end()) {
628-
states[current] = children_it->second;
629+
states[current] = children_it->second.get();
629630
}
630631
}
631632

632633
// We have a cut at the end with states.
633634
for (const auto & state : states) {
634-
int start = state.first;
635+
size_t start = state.first;
635636
llama_trie_node *trie_pointer = state.second;
636637
if (trie_pointer->is_terminator) {
637638
// This is a final match, we need to reset and
638639
// store the results in `offsets`.
639-
int end = text.size();
640+
size_t end = text.size();
640641
offsets.push_back(start);
641642
offsets.push_back(end);
642643
break;
@@ -648,7 +649,7 @@ struct llama_trie {
648649
}
649650

650651
private:
651-
llama_trie_node *root_;
652+
std::unique_ptr<llama_trie_node> root_;
652653
};
653654

654655
#endif

llama.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ struct llama_vocab {
248248

249249
llama_trie special_token_trie;
250250
std::unordered_map<token, id> special_token_to_id;
251-
size_t max_special_token_length;
251+
size_t max_special_token_length = 0;
252252
};
253253

254254
struct llama_context {
@@ -537,7 +537,7 @@ struct llama_file_loader {
537537
vocab.special_token_to_id.reserve(hparams.n_vocab_sp);
538538

539539
for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) {
540-
uint32_t token_id = file.read_u32();
540+
llama_vocab::id token_id = file.read_u32();
541541
const auto & word = vocab.id_to_token[token_id].tok;
542542

543543
vocab.special_token_trie.add(word);
@@ -1961,9 +1961,9 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
19611961
return output;
19621962
}
19631963

1964-
std::vector<int> offsets = vocab.special_token_trie.split(text);
1965-
int start = 0;
1966-
for (int end : offsets) {
1964+
std::vector<size_t> offsets = vocab.special_token_trie.split(text);
1965+
size_t start = 0;
1966+
for (size_t end : offsets) {
19671967
if (start >= end) {
19681968
continue;
19691969
}

0 commit comments

Comments
 (0)