Commit f75fab4

Add beam search to server non-streaming completion posts.
1 parent ee8fbf0 commit f75fab4

1 file changed: examples/server/server.cpp

Lines changed: 78 additions & 13 deletions
@@ -1163,6 +1163,63 @@ static void log_server_request(const Request &req, const Response &res)
     });
 }
 
+bool is_at_eos(llama_server_context&, llama_token const* tokens, size_t const n_tokens) {
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos();
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// This custom callback is called each time the beam lengths increase:
+//  * Show progress by printing ',' followed by the number of convergent beam tokens, if any.
+//  * When all beams converge to a common prefix, that prefix is made available in beams_state.beam_views[0].
+//    This is also called when the stop condition is met.
+//    Convergent tokens are collected into llama.generated_token_probs via callback_state (a llama_server_context*).
+void beam_search_callback(void* callback_state, llama_beams_state beams_state) {
+    auto& llama = *static_cast<llama_server_context*>(callback_state);
+    // Mark beams as EOS as needed.
+    for (size_t i = 0; i < beams_state.n_beams; ++i) {
+        llama_beam_view& beam_view = beams_state.beam_views[i];
+        if (!beam_view.eos && is_at_eos(llama, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eos = true;
+        }
+    }
+    printf(","); // Show progress
+    if (size_t const n = beams_state.common_prefix_length) {
+        llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
+        assert(0u < beams_state.n_beams);
+        llama_token const* tokens = beams_state.beam_views[0].tokens;
+        //std::copy(tokens, tokens + n, llama->generated_token_probs.end() - n);
+        auto const map = [](llama_token tok) { return completion_token_output{{},tok}; };
+        std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
+        printf("%zu", n);
+    }
+    fflush(stdout);
+#if 0 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams:\n";
+    for (size_t i = 0; i < beams_state.n_beams; ++i) {
+        std::cout << "beams[" << i << "]: " << ostream_beam_view{llama.ctx, beams_state.beam_views[i]} << std::endl;
+    }
+#endif
+}
+
+struct token_translator {
+    llama_context* ctx;
+    char const* operator()(llama_token tok)             const { return llama_token_to_str(ctx, tok); }
+    char const* operator()(completion_token_output cto) const { return (*this)(cto.tok); }
+};
+
+void append_to_generated_text_from_generated_token_probs(llama_server_context& llama) {
+    auto& gtps = llama.generated_token_probs;
+    auto translator = token_translator{llama.ctx};
+    auto add_strlen = [=](size_t sum, completion_token_output const& cto) { return sum + strlen(translator(cto)); };
+    size_t const len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
+    if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
+        llama.generated_text.reserve(llama.generated_text.size() + len);
+    }
+    for (completion_token_output const& cto : gtps) {
+        llama.generated_text += translator(cto);
+    }
+}
+
 int main(int argc, char **argv)
 {
     // own arguments required by this example
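
Note (not part of the commit): the callback contract above only requires marking finished beams and flushing the convergent prefix on each call. The following minimal sketch of an alternative callback with the same llama_beam_search_callback_fn_t shape illustrates that contract; it assumes only the llama_beams_state / llama_beam_view fields and llama_token_eos() already used in the hunk above, and it collects the prefix into a plain std::vector<llama_token> instead of the server's generated_token_probs.

#include <vector>
#include "llama.h" // assumed available from the same tree for llama_token, llama_beams_state, llama_beam_view

// Hypothetical helper, not in the commit: collect only the tokens that all beams agree on.
static void collect_prefix_callback(void* callback_state, llama_beams_state beams_state) {
    auto& out = *static_cast<std::vector<llama_token>*>(callback_state);
    // Mark finished beams, mirroring the EOS handling in beam_search_callback above.
    for (size_t i = 0; i < beams_state.n_beams; ++i) {
        llama_beam_view& bv = beams_state.beam_views[i];
        if (!bv.eos && bv.n_tokens && bv.tokens[bv.n_tokens - 1] == llama_token_eos()) {
            bv.eos = true;
        }
    }
    // Whatever prefix the beams share is final and can be flushed immediately.
    if (size_t const n = beams_state.common_prefix_length) {
        llama_token const* tokens = beams_state.beam_views[0].tokens;
        out.insert(out.end(), tokens, tokens + n);
    }
}

The server variant above goes one step further and wraps each token in completion_token_output so per-token probabilities can be attached later.
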
@@ -1245,22 +1302,30 @@ int main(int argc, char **argv)
         llama.beginCompletion();
 
         if (!llama.stream) {
-            size_t stop_pos = std::string::npos;
+            if (llama.params.n_beams) {
+                // Fill llama.generated_token_probs vector with final beam.
+                llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
+                                  llama.n_past, llama.n_remain, llama.params.n_threads);
+                // Translate llama.generated_token_probs to llama.generated_text.
+                append_to_generated_text_from_generated_token_probs(llama);
+            } else {
+                size_t stop_pos = std::string::npos;
 
-            while (llama.has_next_token) {
-                const completion_token_output token_with_probs = llama.doCompletion();
-                const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
+                while (llama.has_next_token) {
+                    const completion_token_output token_with_probs = llama.doCompletion();
+                    const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
 
-                stop_pos = llama.findStoppingStrings(llama.generated_text,
-                                                     token_text.size(), STOP_FULL);
-            }
+                    stop_pos = llama.findStoppingStrings(llama.generated_text,
+                                                         token_text.size(), STOP_FULL);
+                }
 
-            if (stop_pos == std::string::npos) {
-                stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
-            }
-            if (stop_pos != std::string::npos) {
-                llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
-                                           llama.generated_text.end());
+                if (stop_pos == std::string::npos) {
+                    stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
+                }
+                if (stop_pos != std::string::npos) {
+                    llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
+                                               llama.generated_text.end());
+                }
             }
 
             const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
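
The non-streaming beam path therefore works in two steps: the callback accumulates the convergent tokens into llama.generated_token_probs, and append_to_generated_text_from_generated_token_probs then flattens them into llama.generated_text before format_final_response is called. The standalone sketch below (not part of the commit, with simplified stand-in types instead of the server's) shows the same resize + std::transform collection and accumulate + reserve + append flattening in isolation.

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <numeric>
#include <string>
#include <vector>

using fake_token = int;                        // stand-in for llama_token
struct fake_token_output { fake_token tok; };  // stand-in for completion_token_output

int main() {
    std::vector<fake_token_output> generated;  // plays the role of llama.generated_token_probs
    std::string text;                          // plays the role of llama.generated_text

    // 1) Callback side: append the convergent prefix of n tokens taken from beam 0.
    fake_token const prefix[] = {7, 42, 13};
    size_t const n = sizeof(prefix) / sizeof(prefix[0]);
    generated.resize(generated.size() + n);
    std::transform(prefix, prefix + n, generated.end() - n,
                   [](fake_token tok) { return fake_token_output{tok}; });

    // 2) Finalization side: compute the total text length once, reserve, then append.
    auto to_str = [](fake_token tok) { return tok == 42 ? " world" : " hello"; };  // stand-in for llama_token_to_str
    size_t const len = std::accumulate(generated.begin(), generated.end(), size_t(0),
        [&](size_t sum, fake_token_output const& cto) { return sum + strlen(to_str(cto.tok)); });
    text.reserve(text.size() + len);
    for (fake_token_output const& cto : generated) {
        text += to_str(cto.tok);
    }
    printf("%zu tokens -> \"%s\"\n", generated.size(), text.c_str());
    return 0;
}

In the server itself, token_translator plays the role of to_str by calling llama_token_to_str on the real llama_context.
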
