Skip to content

Commit d50ccb0

Browse files
committed
manual merge
2 parents 1c154e9 + dadbed9 commit d50ccb0

File tree

6 files changed

+215
-162
lines changed

6 files changed

+215
-162
lines changed

examples/perplexity/perplexity.cpp

Lines changed: 88 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <cmath>
66
#include <ctime>
77
#include <sstream>
8+
#include <cstring>
89

910
#if defined(_MSC_VER)
1011
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -121,6 +122,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
121122
printf("\n");
122123
}
123124

125+
// Evaluates `tokens` with llama_eval in consecutive chunks of at most `n_batch`
// tokens and returns the concatenation of the logits read back after each chunk.
//
// Parameters:
//   ctx      - llama context passed through to llama_eval / llama_get_logits.
//   tokens   - token ids to evaluate, in order.
//   n_past   - value passed as n_past for the first chunk; advanced by the chunk
//              size after every successful llama_eval call.
//   n_batch  - maximum number of tokens submitted per llama_eval call (must be > 0).
//   n_vocab  - number of logits copied per token (must be > 0).
//   n_thread - thread count forwarded to llama_eval.
//
// Returns:
//   A vector holding n_vocab floats per evaluated token, or an empty vector if
//   llama_eval reports failure or n_batch / n_vocab is not positive. An empty
//   `tokens` input also yields an empty vector, which hellaswag_score treats as
//   a failure.
//
// NOTE(review): n_tokens * n_vocab floats are copied from llama_get_logits after
// each chunk; this presumably requires the context to expose logits for every
// token of the batch - confirm against the context parameters.
// NOTE(review): one call site in hellaswag_score passes an additional
// params.pp_threads argument, which does not match this six-parameter signature;
// confirm which form the merge intended.
std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
                                             int n_vocab, int n_thread) {
    std::vector<float> result;

    // A non-positive n_batch would divide by zero (or wrap to a huge value once
    // converted to size_t) in the chunk arithmetic; a non-positive n_vocab would
    // wrap the buffer size. Report both as a failure via the empty result.
    if (n_batch <= 0 || n_vocab <= 0) {
        fprintf(stderr, "%s : invalid n_batch (%d) or n_vocab (%d)\n", __func__, n_batch, n_vocab);
        return result;
    }

    const size_t batch_size = static_cast<size_t>(n_batch);
    const size_t vocab_size = static_cast<size_t>(n_vocab);

    result.reserve(tokens.size() * vocab_size);

    for (size_t offset = 0; offset < tokens.size(); offset += batch_size) {
        // The last chunk may be shorter than a full batch.
        const size_t n_tokens = std::min(batch_size, tokens.size() - offset);

        if (llama_eval(ctx, tokens.data() + offset, static_cast<int>(n_tokens), n_past, n_thread)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return {};
        }

        // Append this chunk's logits before the next llama_eval call.
        const auto logits = llama_get_logits(ctx);
        result.insert(result.end(), logits, logits + n_tokens * vocab_size);

        n_past += static_cast<int>(n_tokens);
    }

    return result;
}
145+
124146
void hellaswag_score(llama_context * ctx, const gpt_params & params) {
125147
// Calculates hellaswag score (acc_norm) from prompt
126148
//
@@ -209,50 +231,93 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
209231
double acc = 0.0f;
210232
const int n_vocab = llama_n_vocab(ctx);
211233

234+
std::vector<float> tok_logits(n_vocab);
235+
212236
for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
213237

214238
// Tokenize the context to count tokens
215239
std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
216240
size_t context_size = context_embd.size();
217241

218-
for (size_t ending_idx=0;ending_idx<4;ending_idx++) {
242+
// Do the 1st ending
243+
// In this case we include the context when evaluating
244+
auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
245+
auto query_size = query_embd.size();
246+
//printf("First query: %d\n",(int)query_size);
247+
248+
// Stop if query wont fit the ctx window
249+
if (query_size > (size_t)params.n_ctx) {
250+
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
251+
return;
252+
}
253+
254+
// Speedup small evaluations by evaluating atleast 32 tokens
255+
if (query_size < 32) {
256+
query_embd.resize(32);
257+
}
258+
259+
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
260+
if (logits.empty()) {
261+
fprintf(stderr, "%s : failed to eval\n", __func__);
262+
return;
263+
}
264+
265+
std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
266+
const auto first_probs = softmax(tok_logits);
267+
268+
hs_data[task_idx].ending_logprob_count[0] = 1;
269+
hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
270+
271+
// Calculate the logprobs over the ending
272+
for (size_t j = context_size; j < query_size - 1; j++) {
273+
274+
std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
275+
276+
const float prob = softmax(tok_logits)[query_embd[j + 1]];
277+
278+
hs_data[task_idx].ending_logprob[0] += std::log(prob);
279+
hs_data[task_idx].ending_logprob_count[0]++;
280+
}
281+
282+
// Calculate the mean token logprob for acc_norm
283+
hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
284+
285+
// Do the remaining endings
286+
// For these, we use the bare ending with n_past = context_size
287+
//
288+
for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
219289

220290
// Tokenize the query
221-
std::vector<int> query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[ending_idx], prepend_bos);
222-
size_t query_size = query_embd.size();
291+
query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
292+
query_size = query_embd.size();
223293

224294
// Stop if query wont fit the ctx window
225-
if (query_size > (size_t)params.n_ctx) {
295+
if (context_size + query_size > (size_t)params.n_ctx) {
226296
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
227297
return;
228298
}
229299

230300
// Speedup small evaluations by evaluating atleast 32 tokens
231-
if (query_size < 32) {
232-
query_embd.resize(32);
233-
}
301+
// No, resizing to 32 is actually slightly slower (at least on CUDA)
302+
//if (query_size < 32) {
303+
// query_embd.resize(32);
304+
//}
234305

235306
// Evaluate the query
236-
if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads, params.n_threads)) {
307+
logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads, params.pp_threads);
308+
if (logits.empty()) {
237309
fprintf(stderr, "%s : failed to eval\n", __func__);
238310
return;
239311
}
240312

241-
const auto query_logits = llama_get_logits(ctx);
242-
std::vector<float> logits;
243-
logits.insert(logits.end(), query_logits, query_logits + query_size * n_vocab);
244-
245-
hs_data[task_idx].ending_logprob_count[ending_idx] = 0;
246-
hs_data[task_idx].ending_logprob[ending_idx] = 0.0f;
313+
hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
314+
hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
247315

248316
// Calculate the logprobs over the ending
249-
for (size_t j = context_size-1; j < query_size - 1; j++) {
250-
// Calculate probability of next token, given the previous ones.
251-
const std::vector<float> tok_logits(
252-
logits.begin() + (j + 0) * n_vocab,
253-
logits.begin() + (j + 1) * n_vocab);
317+
for (size_t j = 0; j < query_size - 1; j++) {
318+
std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
254319

255-
const float prob = softmax(tok_logits)[query_embd[ j + 1]];
320+
const float prob = softmax(tok_logits)[query_embd[j + 1]];
256321

257322
hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
258323
hs_data[task_idx].ending_logprob_count[ending_idx]++;
@@ -267,9 +332,9 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
267332
}
268333

269334
// Find the ending with maximum logprob
270-
size_t ending_logprob_max_idx = -1;
271-
double ending_logprob_max_val = -INFINITY;
272-
for (size_t j=0; j < 4; j++) {
335+
size_t ending_logprob_max_idx = 0;
336+
double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
337+
for (size_t j = 1; j < 4; j++) {
273338
if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
274339
ending_logprob_max_idx = j;
275340
ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];

examples/server/deps.sh

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ echo >> $PUBLIC/index.js # add newline
1111

1212
FILES=$(ls $PUBLIC)
1313

14+
cd $PUBLIC
1415
for FILE in $FILES; do
15-
func=$(echo $FILE | tr '.' '_')
16-
echo "generate $FILE.hpp ($func)"
17-
xxd -n $func -i $PUBLIC/$FILE > $DIR/$FILE.hpp
16+
echo "generate $FILE.hpp"
17+
18+
# use simple flag for old version of xxd
19+
xxd -i $FILE > $DIR/$FILE.hpp
1820
done

examples/server/public/index.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,12 @@
144144
import { SchemaConverter } from '/json-schema-to-grammar.mjs';
145145

146146
const session = signal({
147-
prompt: "This is a conversation between user and llama, a friendly chatbot. respond in simple markdown.",
147+
prompt: "This is a conversation between User and Llama, a friendly chatbot. Llama is helpful, kind, honest, good at writing, and never fails to answer any requests immediately and with precision.",
148148
template: "{{prompt}}\n\n{{history}}\n{{char}}:",
149149
historyTemplate: "{{name}}: {{message}}",
150150
transcript: [],
151151
type: "chat",
152-
char: "llama",
152+
char: "Llama",
153153
user: "User",
154154
})
155155

ggml-metal.metal

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1898,10 +1898,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
18981898
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
18991899
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
19001900
for (int i = 0; i < 8; i++) {
1901+
threadgroup_barrier(mem_flags::mem_device);
19011902
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
19021903
}
19031904

1904-
threadgroup_barrier(mem_flags::mem_threadgroup);
1905+
threadgroup_barrier(mem_flags::mem_device);
19051906
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
19061907
if (sgitg==0) {
19071908
for (int i = 0; i < n_rows; i++) {

0 commit comments

Comments
 (0)