@@ -98,7 +98,7 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
98
98
sampler_tester tester (probs, probs_expected);
99
99
100
100
DUMP (&tester.cur_p );
101
- tester.apply (llama_sampler_init_top_p (p, 1 ));
101
+ tester.apply (llama_sampler_init_top_p (p, 0 ));
102
102
tester.apply (llama_sampler_init_dist (0 ));
103
103
DUMP (&tester.cur_p );
104
104
@@ -109,7 +109,7 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
109
109
sampler_tester tester (probs, probs_expected);
110
110
111
111
DUMP (&tester.cur_p );
112
- tester.apply (llama_sampler_init_min_p (p, 1 ));
112
+ tester.apply (llama_sampler_init_min_p (p, 0 ));
113
113
tester.apply (llama_sampler_init_dist (0 ));
114
114
DUMP (&tester.cur_p );
115
115
@@ -130,7 +130,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
130
130
sampler_tester tester (probs, probs_expected);
131
131
132
132
DUMP (&tester.cur_p );
133
- tester.apply (llama_sampler_init_typical (p, 1 ));
133
+ tester.apply (llama_sampler_init_typical (p, 0 ));
134
134
DUMP (&tester.cur_p );
135
135
136
136
tester.check ();
@@ -332,6 +332,7 @@ int main(void) {
332
332
test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .7f , 0 .3f /0 .7f }, 0 .74f );
333
333
test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .4f }, 0 .76f );
334
334
test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .4f }, 1 .00f );
335
+ test_min_p ({0 .1f , 0 .2f , 0 .3f , 0 .4f }, {0 .4f /0 .4f }, 1 .05f );
335
336
336
337
printf (" XTC should:\n " );
337
338
test_xtc ({0 .4f , 0 .3f , 0 .2f , 0 .1f }, {0 .1f }, 0 .99f , 0 .09f );
@@ -341,8 +342,8 @@ int main(void) {
341
342
printf (" XTC should not:\n " );
342
343
test_xtc ({0 .4f , 0 .3f , 0 .2f , 0 .1f }, {0 .4f , 0 .3f , 0 .2f , 0 .1f }, 0 .99f , 0 .39f );
343
344
344
- test_typical ({0 .97f , 0 .01f , 0 .01f , 0 .01f }, {0 .97f }, 0 .5f );
345
- test_typical ({0 .4f , 0 .2f , 0 .2f , 0 .2f }, {0 .2f , 0 .2f , 0 .2f }, 0 .5f );
345
+ test_typical ({0 .97f , 0 .01f , 0 .01f , 0 .01f }, {0 .97f }, 0 .5f );
346
+ test_typical ({0 .4f , 0 .2f , 0 .2f , 0 .2f }, {0 .2f , 0 .2f , 0 .2f }, 0 .5f );
346
347
347
348
test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 }, {0 .25f , 0 .25f , 0 .25f , 0 .25f , 0 }, 50 .0f , 0 .0f , 0 .0f );
348
349
test_penalties ({0 .2f , 0 .2f , 0 .2f , 0 .2f , 0 .2f }, {0 , 1 , 2 }, {0 .5f , 0 .5f , 0 , 0 , 0 }, 50 .0f , 0 .0f , 0 .0f );
0 commit comments