
Commit 35670e7

train-text-from-scratch: rename ff tensors
This commit renames the feed-forward tensors w1, w2 and w3 to ffn_gate, ffn_down and ffn_up respectively.

The motivation for this change is to make it easier to understand the purpose of these tensors. The new names also appear to be in line with those used in the llama_layer struct in llama.cpp.

Signed-off-by: Daniel Bevenius <[email protected]>
1 parent fa2c0d5 commit 35670e7
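
For context (not part of this commit): in the llama_build_train_graphs hunk below, these three tensors form the SwiGLU feed-forward block, which is what the new names describe. The following is a minimal sketch of that computation, using only ggml calls that already appear in the diff; the helper name build_ffn_sketch and its signature are illustrative, not code from the repository.

// Sketch only: how ffn_gate (w1), ffn_up (w3) and ffn_down (w2) are used.
// Per the diff, ffn_gate and ffn_up are n_embd x n_ff and ffn_down is n_ff x n_embd.
static struct ggml_tensor * build_ffn_sketch(
        struct ggml_context   * ctx,
        struct my_llama_layer & layer,
        struct ggml_tensor    * cur) {                                   // normalized input, n_embd x n_tokens
    struct ggml_tensor * up    = ggml_mul_mat(ctx, layer.ffn_up,   cur); // t25: project up to n_ff
    struct ggml_tensor * gate  = ggml_mul_mat(ctx, layer.ffn_gate, cur); // t26: gating branch
    struct ggml_tensor * act   = ggml_silu   (ctx, gate);                // t27: SiLU on the gate
    struct ggml_tensor * gated = ggml_mul    (ctx, act, up);             // t28: elementwise silu(gate) * up
    return ggml_mul_mat(ctx, layer.ffn_down, gated);                     // t29: project back down to n_embd
}

This mirrors tensors t25 through t29 in the graph-building code: the gate and up projections expand to n_ff, their SiLU-gated elementwise product is formed, and the down projection returns to n_embd.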

examples/train-text-from-scratch/train-text-from-scratch.cpp

Lines changed: 27 additions & 27 deletions
@@ -50,9 +50,9 @@ struct my_llama_layer {
     struct ggml_tensor * ffn_norm;
 
     // ff
-    struct ggml_tensor * w1;
-    struct ggml_tensor * w2;
-    struct ggml_tensor * w3;
+    struct ggml_tensor * ffn_gate; // w1
+    struct ggml_tensor * ffn_down; // w2
+    struct ggml_tensor * ffn_up;   // w3
 };
 
 struct my_llama_model {
@@ -140,9 +140,9 @@ static void set_param_model(struct my_llama_model * model) {
         ggml_set_param(ctx, layer.wv);
         ggml_set_param(ctx, layer.wo);
         ggml_set_param(ctx, layer.ffn_norm);
-        ggml_set_param(ctx, layer.w1);
-        ggml_set_param(ctx, layer.w2);
-        ggml_set_param(ctx, layer.w3);
+        ggml_set_param(ctx, layer.ffn_gate);
+        ggml_set_param(ctx, layer.ffn_down);
+        ggml_set_param(ctx, layer.ffn_up);
     }
 }
 
@@ -198,9 +198,9 @@ static void init_model(struct my_llama_model * model) {
 
         layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
 
-        layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
-        layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
-        layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+        layer.ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+        layer.ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
+        layer.ffn_up   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
 
         ggml_set_name(layer.attention_norm, tni(LLM_TENSOR_ATTN_NORM, i));
 
@@ -211,9 +211,9 @@ static void init_model(struct my_llama_model * model) {
 
         ggml_set_name(layer.ffn_norm, tni(LLM_TENSOR_FFN_NORM, i));
 
-        ggml_set_name(layer.w1, tni(LLM_TENSOR_FFN_GATE, i));
-        ggml_set_name(layer.w2, tni(LLM_TENSOR_FFN_DOWN, i));
-        ggml_set_name(layer.w3, tni(LLM_TENSOR_FFN_UP, i));
+        ggml_set_name(layer.ffn_gate, tni(LLM_TENSOR_FFN_GATE, i));
+        ggml_set_name(layer.ffn_down, tni(LLM_TENSOR_FFN_DOWN, i));
+        ggml_set_name(layer.ffn_up, tni(LLM_TENSOR_FFN_UP, i));
     }
 
     set_param_model(model);
@@ -244,9 +244,9 @@ static void randomize_model(struct my_llama_model * model, int seed, float mean,
 
         randomize_tensor_normal(layer.ffn_norm, rnd);
 
-        randomize_tensor_normal(layer.w1, rnd);
-        randomize_tensor_normal(layer.w2, rnd);
-        randomize_tensor_normal(layer.w3, rnd);
+        randomize_tensor_normal(layer.ffn_gate, rnd);
+        randomize_tensor_normal(layer.ffn_down, rnd);
+        randomize_tensor_normal(layer.ffn_up, rnd);
     }
 
     free_random_normal_distribution(rnd);
@@ -356,11 +356,11 @@ static struct ggml_tensor * llama_build_train_graphs(
         struct ggml_tensor * t22 = ggml_rms_norm (ctx, t21, f_norm_rms_eps); set_name(t22, "t22"); assert_shape_2d(t22, n_embd, N*n_batch);
         struct ggml_tensor * t23 = ggml_repeat (ctx, layer.ffn_norm, t22); set_name(t23, "t23"); assert_shape_2d(t23, n_embd, N*n_batch);
         struct ggml_tensor * t24 = ggml_mul (ctx, t23, t22); set_name(t24, "t24"); assert_shape_2d(t24, n_embd, N*n_batch);
-        struct ggml_tensor * t25 = ggml_mul_mat (ctx, layer.w3, t24); set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch);
-        struct ggml_tensor * t26 = ggml_mul_mat (ctx, layer.w1, t24); set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch);
+        struct ggml_tensor * t25 = ggml_mul_mat (ctx, layer.ffn_up, t24); set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch);
+        struct ggml_tensor * t26 = ggml_mul_mat (ctx, layer.ffn_gate, t24); set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch);
         struct ggml_tensor * t27 = ggml_silu (ctx, t26); set_name(t27, "t27"); assert_shape_2d(t27, n_ff, N*n_batch);
         struct ggml_tensor * t28 = ggml_mul (ctx, t27, t25); set_name(t28, "t28"); assert_shape_2d(t28, n_ff, N*n_batch);
-        struct ggml_tensor * t29 = ggml_mul_mat (ctx, layer.w2, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
+        struct ggml_tensor * t29 = ggml_mul_mat (ctx, layer.ffn_down, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
         struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); set_name(t30, "t30"); assert_shape_2d(t30, n_embd, N*n_batch);
         cur = t30;
         checkpoints.push_back(cur);
@@ -521,9 +521,9 @@ static void load_llama_model_gguf(struct gguf_context * fctx, struct ggml_contex
         copy_tensor_by_name(layer.wv, f_ggml_ctx, tni(LLM_TENSOR_ATTN_V, i));
         copy_tensor_by_name(layer.wo, f_ggml_ctx, tni(LLM_TENSOR_ATTN_OUT, i));
         copy_tensor_by_name(layer.ffn_norm, f_ggml_ctx, tni(LLM_TENSOR_FFN_NORM, i));
-        copy_tensor_by_name(layer.w1, f_ggml_ctx, tni(LLM_TENSOR_FFN_GATE, i));
-        copy_tensor_by_name(layer.w2, f_ggml_ctx, tni(LLM_TENSOR_FFN_DOWN, i));
-        copy_tensor_by_name(layer.w3, f_ggml_ctx, tni(LLM_TENSOR_FFN_UP, i));
+        copy_tensor_by_name(layer.ffn_gate, f_ggml_ctx, tni(LLM_TENSOR_FFN_GATE, i));
+        copy_tensor_by_name(layer.ffn_down, f_ggml_ctx, tni(LLM_TENSOR_FFN_DOWN, i));
+        copy_tensor_by_name(layer.ffn_up, f_ggml_ctx, tni(LLM_TENSOR_FFN_UP, i));
     }
 }
 
@@ -664,9 +664,9 @@ static void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vo
         gguf_add_tensor(fctx, layer.wv);
         gguf_add_tensor(fctx, layer.wo);
         gguf_add_tensor(fctx, layer.ffn_norm);
-        gguf_add_tensor(fctx, layer.w1);
-        gguf_add_tensor(fctx, layer.w2);
-        gguf_add_tensor(fctx, layer.w3);
+        gguf_add_tensor(fctx, layer.ffn_gate);
+        gguf_add_tensor(fctx, layer.ffn_down);
+        gguf_add_tensor(fctx, layer.ffn_up);
     }
 }
 
@@ -915,9 +915,9 @@ static int64_t get_parameter_count(struct my_llama_model* model) {
         nx += ggml_nelements(layer.wv);
         nx += ggml_nelements(layer.wo);
         nx += ggml_nelements(layer.ffn_norm);
-        nx += ggml_nelements(layer.w1);
-        nx += ggml_nelements(layer.w2);
-        nx += ggml_nelements(layer.w3);
+        nx += ggml_nelements(layer.ffn_gate);
+        nx += ggml_nelements(layer.ffn_down);
+        nx += ggml_nelements(layer.ffn_up);
     }
     return nx;
 }
