
Commit 1e7a009

Merge branch 'master' into gguf
ggml-ci
2 parents: 7a7d1ba + dadbed9

File tree: 7 files changed (+222, -167 lines)


README.md
Lines changed: 7 additions & 6 deletions

@@ -9,13 +9,13 @@
 
 Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
 
-**Hot topics:**
+### 🚧 Incoming breaking change + refactoring:
 
-- Simple web chat example: https://github.com/ggerganov/llama.cpp/pull/1998
-- k-quants now support super-block size of 64: https://github.com/ggerganov/llama.cpp/pull/2001
-- New roadmap: https://github.com/users/ggerganov/projects/7
-- Azure CI brainstorming: https://github.com/ggerganov/llama.cpp/discussions/1985
-- p1 : LLM-based code completion engine at the edge : https://github.com/ggml-org/p1/discussions/1
+See PR https://github.com/ggerganov/llama.cpp/pull/2398 for more info.
+
+To devs: avoid making big changes to `llama.h` / `llama.cpp` until merged
+
+----
 
 <details>
 <summary>Table of Contents</summary>
@@ -99,6 +99,7 @@ as the main playground for developing new features for the [ggml](https://github
 - Rust: [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp)
 - C#/.NET: [SciSharp/LLamaSharp](https://github.com/SciSharp/LLamaSharp)
 - Scala 3: [donderom/llm4s](https://github.com/donderom/llm4s)
+- Clojure: [phronmophobic/llama.clj](https://github.com/phronmophobic/llama.clj)
 
 **UI:**

examples/perplexity/perplexity.cpp
Lines changed: 88 additions & 23 deletions

@@ -5,6 +5,7 @@
 #include <cmath>
 #include <ctime>
 #include <sstream>
+#include <cstring>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -121,6 +122,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     printf("\n");
 }
 
+std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
+                                             int n_vocab, int n_thread) {
+    std::vector<float> result;
+    result.reserve(tokens.size() * n_vocab);
+    size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
+    for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
+        size_t n_tokens = tokens.size() - i_chunk * n_batch;
+        n_tokens = std::min(n_tokens, size_t(n_batch));
+        if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return {};
+        }
+
+        const auto logits = llama_get_logits(ctx);
+        result.insert(result.end(), logits, logits + n_tokens * n_vocab);
+
+        n_past += n_tokens;
+    }
+    return result;
+}
+
 void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     // Calculates hellaswag score (acc_norm) from prompt
     //
@@ -209,50 +231,93 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     double acc = 0.0f;
     const int n_vocab = llama_n_vocab(ctx);
 
+    std::vector<float> tok_logits(n_vocab);
+
     for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
 
         // Tokenize the context to count tokens
         std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
         size_t context_size = context_embd.size();
 
-        for (size_t ending_idx=0;ending_idx<4;ending_idx++) {
+        // Do the 1st ending
+        // In this case we include the context when evaluating
+        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
+        auto query_size = query_embd.size();
+        //printf("First query: %d\n",(int)query_size);
+
+        // Stop if query wont fit the ctx window
+        if (query_size > (size_t)params.n_ctx) {
+            fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
+            return;
+        }
+
+        // Speedup small evaluations by evaluating atleast 32 tokens
+        if (query_size < 32) {
+            query_embd.resize(32);
+        }
+
+        auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
+        if (logits.empty()) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return;
+        }
+
+        std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
+        const auto first_probs = softmax(tok_logits);
+
+        hs_data[task_idx].ending_logprob_count[0] = 1;
+        hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
+
+        // Calculate the logprobs over the ending
+        for (size_t j = context_size; j < query_size - 1; j++) {
+
+            std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
+
+            const float prob = softmax(tok_logits)[query_embd[j + 1]];
+
+            hs_data[task_idx].ending_logprob[0] += std::log(prob);
+            hs_data[task_idx].ending_logprob_count[0]++;
+        }
+
+        // Calculate the mean token logprob for acc_norm
+        hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
+
+        // Do the remaining endings
+        // For these, we use the bare ending with n_past = context_size
+        //
+        for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
 
             // Tokenize the query
-            std::vector<int> query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[ending_idx], prepend_bos);
-            size_t query_size = query_embd.size();
+            query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
+            query_size = query_embd.size();
 
             // Stop if query wont fit the ctx window
-            if (query_size > (size_t)params.n_ctx) {
+            if (context_size + query_size > (size_t)params.n_ctx) {
                 fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
                 return;
             }
 
             // Speedup small evaluations by evaluating atleast 32 tokens
-            if (query_size < 32) {
-                query_embd.resize(32);
-            }
+            // No, resizing to 32 is actually slightly slower (at least on CUDA)
+            //if (query_size < 32) {
+            //    query_embd.resize(32);
+            //}
 
             // Evaluate the query
-            if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+            logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads);
+            if (logits.empty()) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return;
             }
 
-            const auto query_logits = llama_get_logits(ctx);
-            std::vector<float> logits;
-            logits.insert(logits.end(), query_logits, query_logits + query_size * n_vocab);
-
-            hs_data[task_idx].ending_logprob_count[ending_idx] = 0;
-            hs_data[task_idx].ending_logprob[ending_idx] = 0.0f;
+            hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
+            hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
 
             // Calculate the logprobs over the ending
-            for (size_t j = context_size-1; j < query_size - 1; j++) {
-                // Calculate probability of next token, given the previous ones.
-                const std::vector<float> tok_logits(
-                    logits.begin() + (j + 0) * n_vocab,
-                    logits.begin() + (j + 1) * n_vocab);
+            for (size_t j = 0; j < query_size - 1; j++) {
+                std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
 
-                const float prob = softmax(tok_logits)[query_embd[ j + 1]];
+                const float prob = softmax(tok_logits)[query_embd[j + 1]];
 
                 hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
                 hs_data[task_idx].ending_logprob_count[ending_idx]++;
@@ -267,9 +332,9 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         }
 
        // Find the ending with maximum logprob
-        size_t ending_logprob_max_idx = -1;
-        double ending_logprob_max_val = -INFINITY;
-        for (size_t j=0; j < 4; j++) {
+        size_t ending_logprob_max_idx = 0;
+        double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
+        for (size_t j = 1; j < 4; j++) {
            if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
                ending_logprob_max_idx = j;
                ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];
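
Note on the perplexity.cpp change: the new hellaswag_evaluate_tokens() helper evaluates a query in n_batch-sized chunks via llama_eval() and concatenates the returned logits; the first ending is now scored together with its context, and the three remaining endings reuse that context through n_past = context_size. Each ending is then ranked by its mean token log-probability (acc_norm). The sketch below isolates just that final selection step with toy numbers; pick_best_ending() is a hypothetical helper written for illustration, not code from the patch.

```cpp
// Minimal standalone sketch (not from the commit): rank four endings by mean token
// log-probability, mirroring the rewritten selection loop that now starts from
// ending 0 instead of the old idx = -1 / -INFINITY sentinel.
#include <array>
#include <cstddef>
#include <cstdio>

struct EndingScore {
    double logprob_sum; // sum of log P(token | prefix) over the ending's tokens
    int    count;       // number of tokens that were scored
};

// Hypothetical helper: compute the mean log-prob per ending and return the argmax.
static size_t pick_best_ending(const std::array<EndingScore, 4> & endings) {
    size_t best_idx = 0;
    double best_val = endings[0].logprob_sum / endings[0].count;
    for (size_t j = 1; j < 4; ++j) {
        const double mean = endings[j].logprob_sum / endings[j].count;
        if (mean > best_val) {
            best_idx = j;
            best_val = mean;
        }
    }
    return best_idx;
}

int main() {
    // Toy numbers for illustration only.
    const std::array<EndingScore, 4> endings = {{ {-12.3, 7}, {-9.8, 6}, {-15.1, 9}, {-11.0, 8} }};
    printf("best ending: %zu\n", pick_best_ending(endings)); // ending 3: -11.0/8 = -1.375 is the highest mean
    return 0;
}
```

Starting the search from ending 0 is safe in the rewritten code because every ending now begins with one scored token (ending_logprob_count initialized to 1), so no sentinel value is needed.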

examples/server/deps.sh
Lines changed: 5 additions & 3 deletions

@@ -11,8 +11,10 @@ echo >> $PUBLIC/index.js # add newline
 
 FILES=$(ls $PUBLIC)
 
+cd $PUBLIC
 for FILE in $FILES; do
-  func=$(echo $FILE | tr '.' '_')
-  echo "generate $FILE.hpp ($func)"
-  xxd -n $func -i $PUBLIC/$FILE > $DIR/$FILE.hpp
+  echo "generate $FILE.hpp"
+
+  # use simple flag for old version of xxd
+  xxd -i $FILE > $DIR/$FILE.hpp
 done
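
Note on the deps.sh change: xxd -i derives the C identifier for the emitted array from the path it is given (characters that are not valid in an identifier become underscores), and the -n name override used previously is missing from older xxd releases; adding cd $PUBLIC keeps the directory prefix out of the symbol name without it. For orientation, the generated header has roughly the shape sketched below; the byte values and length are placeholders rather than real output.

```cpp
// Approximate shape of the header emitted by `xxd -i index.html` when run from
// inside $PUBLIC (so the symbol is index_html rather than a path-derived name).
// Placeholder bytes, shown for illustration only.
unsigned char index_html[] = {
    0x3c, 0x21, 0x64, 0x6f, 0x63, 0x74, 0x79, 0x70, 0x65, 0x20, 0x68, 0x74,
    0x6d, 0x6c, 0x3e
};
unsigned int index_html_len = 15;
```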

examples/server/public/index.html
Lines changed: 2 additions & 2 deletions

@@ -144,12 +144,12 @@
     import { SchemaConverter } from '/json-schema-to-grammar.mjs';
 
     const session = signal({
-      prompt: "This is a conversation between user and llama, a friendly chatbot. respond in simple markdown.",
+      prompt: "This is a conversation between User and Llama, a friendly chatbot. Llama is helpful, kind, honest, good at writing, and never fails to answer any requests immediately and with precision.",
       template: "{{prompt}}\n\n{{history}}\n{{char}}:",
       historyTemplate: "{{name}}: {{message}}",
       transcript: [],
       type: "chat",
-      char: "llama",
+      char: "Llama",
       user: "User",
     })

ggml-metal.metal
Lines changed: 2 additions & 1 deletion

@@ -1898,10 +1898,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
         threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         for (int i = 0; i < 8; i++) {
+            threadgroup_barrier(mem_flags::mem_device);
             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
         }
 
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_device);
         device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
         if (sgitg==0) {
             for (int i = 0; i < n_rows; i++) {