16
16
#include < vector>
17
17
#include < map>
18
18
#include < unordered_map>
19
+ #include < memory>
19
20
#include < stdexcept>
20
21
21
22
#ifdef __has_include
@@ -546,7 +547,7 @@ typedef llama_buffer llama_ctx_buffer;
546
547
// One node of the trie: maps a next character to the child node reached by that character,
// and records whether a word added to the trie ends at this node.
// NOTE(review): this span is a rendered diff, not compilable source — the bare numbers are
// old/new line numbers, and the `-` / `+` lines are the removed / added form of one member.
struct llama_trie_node {
547
548
// A newly constructed node does not mark the end of a word.
llama_trie_node (): is_terminator(false ) {}
548
549
549
// Removed form: children are held through raw pointers.
- std::unordered_map<char , llama_trie_node* > children;
550
// Added form: children are held through std::unique_ptr, so destroying a node destroys its subtree.
+ std::unordered_map<char , std::unique_ptr< llama_trie_node> > children;
550
551
// Set to true on the node reached by the last character of an inserted word.
bool is_terminator;
551
552
};
552
553
@@ -561,24 +562,24 @@ struct llama_trie {
561
562
return ;
562
563
}
563
564
564
- llama_trie_node *ref = root_;
565
+ llama_trie_node *ref = root_. get () ;
565
566
for (char c : word) {
566
567
if (ref->children .find (c) == ref->children .end ()) {
567
- ref->children [c] = new llama_trie_node ();
568
+ ref->children [c]. reset ( new llama_trie_node () );
568
569
}
569
- ref = ref->children [c];
570
+ ref = ref->children [c]. get () ;
570
571
}
571
572
ref->is_terminator = true ;
572
573
}
573
574
574
575
// Will look for the words added to the trie within `text`. Output is the boundaries of the words found.
575
576
// Note that this trie will match the longest possible word first!
576
- std::vector<int > split (const std::string & text) const {
577
- std::map<int , llama_trie_node*> states;
578
- std::vector<int > offsets{0 };
577
+ std::vector<size_t > split (const std::string & text) const {
578
+ std::map<size_t , llama_trie_node*> states;
579
+ std::vector<size_t > offsets{0 };
579
580
580
- int skip = 0 ;
581
- for (int current = 0 ; current < text.size (); current++) {
581
+ size_t skip = 0 ;
582
+ for (size_t current = 0 ; current < text.size (); current++) {
582
583
char current_char = text[current];
583
584
if (skip > 0 && current < skip) {
584
585
// Prevents the lookahead for matching twice
@@ -592,7 +593,7 @@ struct llama_trie {
592
593
593
594
// In this case, we already have partial matches (But unfinished)
594
595
for (auto state = states.begin (); state != states.end (); ) {
595
- int start = state->first ;
596
+ size_t start = state->first ;
596
597
llama_trie_node *trie_pointer = state->second ;
597
598
if (trie_pointer->is_terminator ) {
598
599
// This is a final match, we need to reset and
@@ -603,11 +604,11 @@ struct llama_trie {
603
604
// Here we are also actively looking for other earlier partial
604
605
// matches
605
606
// "[CLS]", "L", we need to match CLS even if L is special
606
- int end = 0 ;
607
+ size_t end = 0 ;
607
608
for (const auto & look : states) {
608
- int lookstart = look.first ;
609
+ size_t lookstart = look.first ;
609
610
llama_trie_node *looktrie_pointer = look.second ;
610
- int lookahead_index = 0 ;
611
+ size_t lookahead_index = 0 ;
611
612
if (lookstart > start) {
612
613
// This partial match is later, we can stop looking
613
614
break ;
@@ -633,7 +634,7 @@ struct llama_trie {
633
634
634
635
auto looktrie_pointer_it = looktrie_pointer->children .find (next_char);
635
636
while (looktrie_pointer_it != looktrie_pointer->children .end ()) {
636
- looktrie_pointer = looktrie_pointer_it->second ;
637
+ looktrie_pointer = looktrie_pointer_it->second . get () ;
637
638
lookahead_index++;
638
639
if (looktrie_pointer->is_terminator ) {
639
640
start = lookstart;
@@ -660,7 +661,7 @@ struct llama_trie {
660
661
if (trie_pointer_it != trie_pointer->children .end ()) {
661
662
// The current character being looked at has a match within the trie
662
663
// update the pointer (it will be stored back into states later).
663
- trie_pointer = trie_pointer_it->second ;
664
+ trie_pointer = trie_pointer_it->second . get () ;
664
665
states[start] = trie_pointer;
665
666
++state;
666
667
} else {
@@ -679,18 +680,18 @@ struct llama_trie {
679
680
// start keeping track of this partial match.
680
681
auto children_it = root_->children .find (current_char);
681
682
if (current >= skip && children_it != root_->children .end ()) {
682
- states[current] = children_it->second ;
683
+ states[current] = children_it->second . get () ;
683
684
}
684
685
}
685
686
686
687
// We have a cut at the end with states.
687
688
for (const auto & state : states) {
688
- int start = state.first ;
689
+ size_t start = state.first ;
689
690
llama_trie_node *trie_pointer = state.second ;
690
691
if (trie_pointer->is_terminator ) {
691
692
// This is a final match, we need to reset and
692
693
// store the results in `offsets`.
693
- int end = text.size ();
694
+ size_t end = text.size ();
694
695
offsets.push_back (start);
695
696
offsets.push_back (end);
696
697
break ;
@@ -702,7 +703,7 @@ struct llama_trie {
702
703
}
703
704
704
705
private:
705
- llama_trie_node * root_;
706
+ std::unique_ptr< llama_trie_node> root_;
706
707
};
707
708
708
709
#endif
0 commit comments