Skip to content

Commit a5e8c1d

Browse files
authored
train-text-from-scratch : fix assert failure in ggml-alloc (#3618)
1 parent e74c705 commit a5e8c1d

File tree

1 file changed

+9 −10 lines changed

1 file changed

+9 −10 lines changed

examples/train-text-from-scratch/train-text-from-scratch.cpp

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -253,13 +253,14 @@ static void init_model(struct my_llama_model * model) {
253253
set_param_model(model);
254254

255255
// measure data size
256-
struct ggml_allocr * alloc = NULL;
257-
alloc = ggml_allocr_new_measure(tensor_alignment);
258-
alloc_model(alloc, model);
256+
size_t size = 0;
257+
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
258+
size += GGML_PAD(ggml_nbytes(t), tensor_alignment);
259+
}
259260

260261
// allocate data
261-
model->data.resize(ggml_allocr_max_size(alloc) + tensor_alignment);
262-
ggml_allocr_free(alloc);
262+
struct ggml_allocr * alloc = NULL;
263+
model->data.resize(size + tensor_alignment);
263264
alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment);
264265
alloc_model(alloc, model);
265266
ggml_allocr_free(alloc);
@@ -1094,11 +1095,9 @@ int main(int argc, char ** argv) {
10941095
struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
10951096

10961097
// measure required memory for input tensors
1097-
alloc = ggml_allocr_new_measure(tensor_alignment);
1098-
ggml_allocr_alloc(alloc, tokens_input);
1099-
ggml_allocr_alloc(alloc, target_probs);
1100-
size_t max_input_size = ggml_allocr_max_size(alloc) + tensor_alignment;
1101-
ggml_allocr_free(alloc);
1098+
size_t max_input_size = GGML_PAD(ggml_nbytes(tokens_input), tensor_alignment) +
1099+
GGML_PAD(ggml_nbytes(target_probs), tensor_alignment) +
1100+
tensor_alignment;
11021101
printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
11031102

11041103
// allocate input tensors

0 commit comments

Comments
 (0)