Skip to content

Commit c1845a9

Browse files
committed
cont : fixes + add test [no ci]
1 parent 2681638 commit c1845a9

File tree

5 files changed

+86
-68
lines changed

5 files changed

+86
-68
lines changed

Makefile

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ TEST_TARGETS = \
5454
tests/test-grammar-parser \
5555
tests/test-json-schema-to-grammar \
5656
tests/test-llama-grammar \
57+
tests/test-log \
5758
tests/test-model-load-cancel \
5859
tests/test-opt \
5960
tests/test-quantize-fns \
@@ -1528,6 +1529,11 @@ tests/test-llama-grammar: tests/test-llama-grammar.cpp \
15281529
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
15291530
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
15301531

1532+
tests/test-log: tests/test-log.cpp \
1533+
$(OBJ_ALL)
1534+
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
1535+
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
1536+
15311537
tests/test-grammar-parser: tests/test-grammar-parser.cpp \
15321538
$(OBJ_ALL)
15331539
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

common/log.cpp

Lines changed: 40 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
#include <cstdio>
66
#include <condition_variable>
77

8-
#define LOG_MAX_MESSAGE_SIZE (256)
9-
108
#define LOG_COLORS // TMP
119

1210
#ifdef LOG_COLORS
@@ -41,11 +39,7 @@ struct gpt_log_entry {
4139
int verbosity;
4240
int64_t timestamp;
4341

44-
// static-sized message
45-
char msg_stck[LOG_MAX_MESSAGE_SIZE];
46-
47-
// if it doesn't fit in the stack, it goes here
48-
std::vector<char> msg_heap;
42+
std::vector<char> msg;
4943

5044
// signals the worker thread to stop
5145
bool is_end;
@@ -79,11 +73,7 @@ struct gpt_log_entry {
7973
}
8074
}
8175

82-
if (msg_heap.empty()) {
83-
fprintf(file, "%s", msg_stck);
84-
} else {
85-
fprintf(file, "%s", msg_heap.data());
86-
}
76+
fprintf(file, "%s", msg.data());
8777

8878
fflush(file);
8979
}
@@ -98,7 +88,6 @@ struct gpt_log {
9888
entries.resize(capacity);
9989
head = 0;
10090
tail = 0;
101-
buffer.resize(1024);
10291

10392
resume();
10493
}
@@ -111,9 +100,7 @@ struct gpt_log {
111100
}
112101

113102
private:
114-
std::mutex mtx_inp;
115-
std::mutex mtx_wrk;
116-
103+
std::mutex mtx;
117104
std::thread thrd;
118105
std::condition_variable cv;
119106

@@ -129,39 +116,27 @@ struct gpt_log {
129116
size_t head;
130117
size_t tail;
131118

132-
// print the message here before pushing
133-
std::vector<char> buffer;
119+
// worker thread copies into this
120+
gpt_log_entry cur;
134121

135122
public:
136123
void add(enum ggml_log_level level, int verbosity, const char * fmt, va_list args) {
137-
std::unique_lock<std::mutex> lock_inp(mtx_inp);
124+
std::lock_guard<std::mutex> lock(mtx);
138125

139126
if (!running) {
140127
return;
141128
}
142129

143-
const size_t n = vsnprintf(buffer.data(), buffer.size(), fmt, args);
144-
if (n >= buffer.size()) {
145-
buffer.resize(n + 1);
146-
vsnprintf(buffer.data(), buffer.size(), fmt, args);
147-
}
148-
149-
std::lock_guard<std::mutex> lock_wrk(mtx_wrk);
150-
151130
auto & entry = entries[tail];
152131

153-
if (n < LOG_MAX_MESSAGE_SIZE) {
154-
memcpy(entry.msg_stck, buffer.data(), n);
155-
entry.msg_stck[n] = '\0';
156-
entry.msg_heap.clear();
157-
} else {
158-
entry.msg_heap.resize(n + 1);
159-
memcpy(entry.msg_heap.data(), buffer.data(), n);
160-
entry.msg_heap[n] = '\0';
132+
{
133+
const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args);
134+
if (n >= entry.msg.size()) {
135+
entry.msg.resize(n + 1);
136+
vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args);
137+
}
161138
}
162139

163-
lock_inp.unlock();
164-
165140
entry.level = level;
166141
entry.verbosity = verbosity;
167142
entry.timestamp = 0;
@@ -173,20 +148,18 @@ struct gpt_log {
173148
tail = (tail + 1) % entries.size();
174149
if (tail == head) {
175150
// expand the buffer
176-
size_t new_size = entries.size() * 2;
177-
std::vector<gpt_log_entry> new_entries(new_size);
151+
std::vector<gpt_log_entry> new_entries(2*entries.size());
178152

179-
size_t new_head = 0;
180153
size_t new_tail = 0;
181154

182-
while (head != tail) {
183-
new_entries[new_tail] = entries[head];
155+
do {
156+
new_entries[new_tail] = std::move(entries[head]);
184157

185-
head = (head + 1) % entries.size();
186-
new_tail = (new_tail + 1) % new_size;
187-
}
158+
head = (head + 1) % entries.size();
159+
new_tail = (new_tail + 1);
160+
} while (head != tail);
188161

189-
head = new_head;
162+
head = 0;
190163
tail = new_tail;
191164

192165
entries = std::move(new_entries);
@@ -196,7 +169,7 @@ struct gpt_log {
196169
}
197170

198171
void resume() {
199-
std::lock_guard<std::mutex> lock_inp(mtx_inp);
172+
std::lock_guard<std::mutex> lock(mtx);
200173

201174
if (running) {
202175
return;
@@ -206,35 +179,37 @@ struct gpt_log {
206179

207180
thrd = std::thread([this]() {
208181
while (true) {
209-
std::unique_lock<std::mutex> lock_wrk(mtx_wrk);
210-
cv.wait(lock_wrk, [this]() { return head != tail; });
182+
{
183+
std::unique_lock<std::mutex> lock(mtx);
184+
cv.wait(lock, [this]() { return head != tail; });
211185

212-
auto & entry = entries[head];
186+
cur = entries[head];
213187

214-
if (entry.is_end) {
188+
head = (head + 1) % entries.size();
189+
}
190+
191+
if (cur.is_end) {
215192
break;
216193
}
217194

218-
entry.print(stdout);
195+
cur.print(stdout);
219196

220197
if (file) {
221-
entry.print(file);
198+
cur.print(file);
222199
}
223-
224-
head = (head + 1) % entries.size();
225200
}
226201
});
227202
}
228203

229204
void pause() {
230-
std::lock_guard<std::mutex> lock_inp(mtx_inp);
205+
{
206+
std::lock_guard<std::mutex> lock(mtx);
231207

232-
if (!running) {
233-
return;
234-
}
208+
if (!running) {
209+
return;
210+
}
235211

236-
{
237-
std::lock_guard<std::mutex> lock_wrk(mtx_wrk);
212+
running = false;
238213

239214
auto & entry = entries[tail];
240215

@@ -246,13 +221,10 @@ struct gpt_log {
246221
}
247222

248223
thrd.join();
249-
250-
running = false;
251224
}
252225

253226
void set_file(const char * path) {
254-
std::lock_guard<std::mutex> lock_inp(mtx_inp);
255-
std::lock_guard<std::mutex> lock_wrk(mtx_wrk);
227+
pause();
256228

257229
if (file) {
258230
fclose(file);
@@ -263,11 +235,12 @@ struct gpt_log {
263235
} else {
264236
file = nullptr;
265237
}
238+
239+
resume();
266240
}
267241

268242
void set_timestamps(bool timestamps) {
269-
std::lock_guard<std::mutex> lock_inp(mtx_inp);
270-
std::lock_guard<std::mutex> lock_wrk(mtx_wrk);
243+
std::lock_guard<std::mutex> lock(mtx);
271244

272245
this->timestamps = timestamps;
273246
}

common/log.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ void gpt_log_free (struct gpt_log * log);
2727
LOG_ATTRIBUTE_FORMAT(4, 5)
2828
void gpt_log_add(struct gpt_log * log, enum ggml_log_level level, int verbosity, const char * fmt, ...);
2929

30-
void gpt_log_set_file (struct gpt_log * log, const char * file);
30+
void gpt_log_set_file (struct gpt_log * log, const char * file); // not thread-safe
3131
void gpt_log_set_timestamps(struct gpt_log * log, bool timestamps);
3232

3333
#define LOG_TMPL(level, verbosity, ...) \

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ llama_test(test-tokenizer-1-spm NAME test-tokenizer-1-llama-spm ARGS ${CMAKE_CU
108108
#llama_test(test-tokenizer-1-spm NAME test-tokenizer-1-baichuan ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
109109

110110
# llama_target_and_test(test-double-float.cpp) # SLOW
111+
llama_target_and_test(test-log.cpp)
111112
llama_target_and_test(test-arg-parser.cpp)
112113
llama_target_and_test(test-quantize-fns.cpp)
113114
llama_target_and_test(test-quantize-perf.cpp)

tests/test-log.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include "log.h"
2+
3+
#include <thread>
4+
5+
int main() {
6+
const int n_thread = 8;
7+
const int n_msg = 1000;
8+
9+
std::thread threads[n_thread];
10+
for (int i = 0; i < n_thread; i++) {
11+
threads[i] = std::thread([i]() {
12+
for (int j = 0; j < n_msg; j++) {
13+
const int log_type = std::rand() % 4;
14+
15+
switch (log_type) {
16+
case 0: LOG_INF("Thread %d: %d\n", i, j); break;
17+
case 1: LOG_WRN("Thread %d: %d\n", i, j); break;
18+
case 2: LOG_ERR("Thread %d: %d\n", i, j); break;
19+
case 3: LOG_DBG("Thread %d: %d\n", i, j); break;
20+
default:
21+
break;
22+
}
23+
24+
if (rand () % 10 < 5) {
25+
gpt_log_set_timestamps(gpt_log_main(), rand() % 2);
26+
}
27+
}
28+
});
29+
}
30+
31+
for (int i = 0; i < n_thread; i++) {
32+
threads[i].join();
33+
}
34+
35+
gpt_log_pause(gpt_log_main());
36+
37+
return 0;
38+
}

0 commit comments

Comments
 (0)