16
16
#include < vector>
17
17
#include < map>
18
18
#include < unordered_map>
19
+ #include < memory>
19
20
#include < stdexcept>
20
21
21
22
#ifdef __has_include
@@ -492,7 +493,7 @@ typedef llama_buffer llama_ctx_buffer;
492
493
// A single node of the word-matching trie.
//
// Every node owns its children through std::unique_ptr, so destroying a node
// also destroys the whole subtree below it; no manual cleanup is required.
struct llama_trie_node {
    llama_trie_node() : is_terminator(false) {}

    // Child nodes, keyed by the next character of a word.
    std::unordered_map<char, std::unique_ptr<llama_trie_node>> children;

    // Set to true on the node reached after the last character of an added word.
    bool is_terminator;
};
498
499
@@ -507,24 +508,24 @@ struct llama_trie {
507
508
return ;
508
509
}
509
510
510
- llama_trie_node *ref = root_;
511
+ llama_trie_node *ref = root_. get () ;
511
512
for (char c : word) {
512
513
if (ref->children .find (c) == ref->children .end ()) {
513
- ref->children [c] = new llama_trie_node ();
514
+ ref->children [c]. reset ( new llama_trie_node () );
514
515
}
515
- ref = ref->children [c];
516
+ ref = ref->children [c]. get () ;
516
517
}
517
518
ref->is_terminator = true ;
518
519
}
519
520
520
521
// Will look for the words added to the trie within `text`. Output is the boundaries of the words found.
521
522
// Note that this trie will match the longest possible word first!
522
- std::vector<int > split (const std::string & text) const {
523
- std::map<int , llama_trie_node*> states;
524
- std::vector<int > offsets{0 };
523
+ std::vector<size_t > split (const std::string & text) const {
524
+ std::map<size_t , llama_trie_node*> states;
525
+ std::vector<size_t > offsets{0 };
525
526
526
- int skip = 0 ;
527
- for (int current = 0 ; current < text.size (); current++) {
527
+ size_t skip = 0 ;
528
+ for (size_t current = 0 ; current < text.size (); current++) {
528
529
char current_char = text[current];
529
530
if (skip > 0 && current < skip) {
530
531
// Prevents the lookahead for matching twice
@@ -538,7 +539,7 @@ struct llama_trie {
538
539
539
540
// In this case, we already have partial matches (But unfinished)
540
541
for (auto state = states.begin (); state != states.end (); ) {
541
- int start = state->first ;
542
+ size_t start = state->first ;
542
543
llama_trie_node *trie_pointer = state->second ;
543
544
if (trie_pointer->is_terminator ) {
544
545
// This is a final match, we need to reset and
@@ -549,11 +550,11 @@ struct llama_trie {
549
550
// Here we are also actively looking for other earlier partial
550
551
// matches
551
552
// "[CLS]", "L", we need to match CLS even if L is special
552
- int end = 0 ;
553
+ size_t end = 0 ;
553
554
for (const auto & look : states) {
554
- int lookstart = look.first ;
555
+ size_t lookstart = look.first ;
555
556
llama_trie_node *looktrie_pointer = look.second ;
556
- int lookahead_index = 0 ;
557
+ size_t lookahead_index = 0 ;
557
558
if (lookstart > start) {
558
559
// This partial match is later, we can stop looking
559
560
break ;
@@ -579,7 +580,7 @@ struct llama_trie {
579
580
580
581
auto looktrie_pointer_it = looktrie_pointer->children .find (next_char);
581
582
while (looktrie_pointer_it != looktrie_pointer->children .end ()) {
582
- looktrie_pointer = looktrie_pointer_it->second ;
583
+ looktrie_pointer = looktrie_pointer_it->second . get () ;
583
584
lookahead_index++;
584
585
if (looktrie_pointer->is_terminator ) {
585
586
start = lookstart;
@@ -606,7 +607,7 @@ struct llama_trie {
606
607
if (trie_pointer_it != trie_pointer->children .end ()) {
607
608
// The current character being looked at has a match within the trie
608
609
// update the pointer (it will be stored back into states later).
609
- trie_pointer = trie_pointer_it->second ;
610
+ trie_pointer = trie_pointer_it->second . get () ;
610
611
states[start] = trie_pointer;
611
612
++state;
612
613
} else {
@@ -625,18 +626,18 @@ struct llama_trie {
625
626
// start keeping track of this partial match.
626
627
auto children_it = root_->children .find (current_char);
627
628
if (current >= skip && children_it != root_->children .end ()) {
628
- states[current] = children_it->second ;
629
+ states[current] = children_it->second . get () ;
629
630
}
630
631
}
631
632
632
633
// We have a cut at the end with states.
633
634
for (const auto & state : states) {
634
- int start = state.first ;
635
+ size_t start = state.first ;
635
636
llama_trie_node *trie_pointer = state.second ;
636
637
if (trie_pointer->is_terminator ) {
637
638
// This is a final match, we need to reset and
638
639
// store the results in `offsets`.
639
- int end = text.size ();
640
+ size_t end = text.size ();
640
641
offsets.push_back (start);
641
642
offsets.push_back (end);
642
643
break ;
@@ -648,7 +649,7 @@ struct llama_trie {
648
649
}
649
650
650
651
private:
651
- llama_trie_node * root_;
652
+ std::unique_ptr< llama_trie_node> root_;
652
653
};
653
654
654
655
#endif
0 commit comments