@@ -50,9 +50,9 @@ struct my_llama_layer {
     struct ggml_tensor * ffn_norm;
 
     // ff
-    struct ggml_tensor * w1;
-    struct ggml_tensor * w2;
-    struct ggml_tensor * w3;
+    struct ggml_tensor * ffn_gate; // w1
+    struct ggml_tensor * ffn_down; // w2
+    struct ggml_tensor * ffn_up;   // w3
 };
 
 struct my_llama_model {
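For reference, these three matrices implement the SwiGLU feed-forward block used by LLaMA, which is what makes gate/down/up more descriptive than w1/w2/w3. A sketch of the convention (assuming a column activation x of width n_embd; not part of the diff):

    FFN(x) = ffn_down * (SiLU(ffn_gate * x) ⊙ (ffn_up * x))

ffn_gate and ffn_up project n_embd -> n_ff, and ffn_down projects n_ff back to n_embd.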
@@ -140,9 +140,9 @@ static void set_param_model(struct my_llama_model * model) {
         ggml_set_param(ctx, layer.wv);
         ggml_set_param(ctx, layer.wo);
         ggml_set_param(ctx, layer.ffn_norm);
-        ggml_set_param(ctx, layer.w1);
-        ggml_set_param(ctx, layer.w2);
-        ggml_set_param(ctx, layer.w3);
+        ggml_set_param(ctx, layer.ffn_gate);
+        ggml_set_param(ctx, layer.ffn_down);
+        ggml_set_param(ctx, layer.ffn_up);
     }
 }
 
@@ -198,9 +198,9 @@ static void init_model(struct my_llama_model * model) {
 
         layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
 
-        layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
-        layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
-        layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+        layer.ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+        layer.ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
+        layer.ffn_up   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
 
         ggml_set_name(layer.attention_norm, tni(LLM_TENSOR_ATTN_NORM, i));
 
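The ne0/ne1 order above follows ggml's convention that ggml_mul_mat(A, x) contracts over ne[0] and produces a result whose ne[0] equals A->ne[1]. A minimal standalone sketch of that convention (hypothetical sizes, assuming only ggml.h; not part of this commit):

    #include "ggml.h"
    #include <assert.h>

    int main(void) {
        struct ggml_init_params params = { 16*1024*1024, NULL, false };
        struct ggml_context * ctx = ggml_init(params);

        const int n_embd = 8, n_ff = 32, n_tokens = 4;

        // Same layout as init_model: ne0 is the dimension that ggml_mul_mat contracts.
        struct ggml_tensor * ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
        struct ggml_tensor * x        = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);

        struct ggml_tensor * y = ggml_mul_mat(ctx, ffn_gate, x);
        assert(y->ne[0] == n_ff && y->ne[1] == n_tokens);  // n_embd -> n_ff, per token

        ggml_free(ctx);
        return 0;
    }

So ffn_gate and ffn_up (ne0 = n_embd) consume the embedding and emit n_ff, while ffn_down (ne0 = n_ff) maps back to n_embd.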
@@ -211,9 +211,9 @@ static void init_model(struct my_llama_model * model) {
 
         ggml_set_name(layer.ffn_norm, tni(LLM_TENSOR_FFN_NORM, i));
 
-        ggml_set_name(layer.w1, tni(LLM_TENSOR_FFN_GATE, i));
-        ggml_set_name(layer.w2, tni(LLM_TENSOR_FFN_DOWN, i));
-        ggml_set_name(layer.w3, tni(LLM_TENSOR_FFN_UP, i));
+        ggml_set_name(layer.ffn_gate, tni(LLM_TENSOR_FFN_GATE, i));
+        ggml_set_name(layer.ffn_down, tni(LLM_TENSOR_FFN_DOWN, i));
+        ggml_set_name(layer.ffn_up,   tni(LLM_TENSOR_FFN_UP,   i));
     }
 
     set_param_model(model);
@@ -244,9 +244,9 @@ static void randomize_model(struct my_llama_model * model, int seed, float mean,
 
         randomize_tensor_normal(layer.ffn_norm, rnd);
 
-        randomize_tensor_normal(layer.w1, rnd);
-        randomize_tensor_normal(layer.w2, rnd);
-        randomize_tensor_normal(layer.w3, rnd);
+        randomize_tensor_normal(layer.ffn_gate, rnd);
+        randomize_tensor_normal(layer.ffn_down, rnd);
+        randomize_tensor_normal(layer.ffn_up,   rnd);
     }
 
     free_random_normal_distribution(rnd);
@@ -356,11 +356,11 @@ static struct ggml_tensor * llama_build_train_graphs(
         struct ggml_tensor * t22 = ggml_rms_norm(ctx, t21, f_norm_rms_eps);     set_name(t22, "t22"); assert_shape_2d(t22, n_embd, N*n_batch);
         struct ggml_tensor * t23 = ggml_repeat  (ctx, layer.ffn_norm, t22);     set_name(t23, "t23"); assert_shape_2d(t23, n_embd, N*n_batch);
         struct ggml_tensor * t24 = ggml_mul     (ctx, t23, t22);                set_name(t24, "t24"); assert_shape_2d(t24, n_embd, N*n_batch);
-        struct ggml_tensor * t25 = ggml_mul_mat (ctx, layer.w3, t24);           set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch);
-        struct ggml_tensor * t26 = ggml_mul_mat (ctx, layer.w1, t24);           set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch);
+        struct ggml_tensor * t25 = ggml_mul_mat (ctx, layer.ffn_up, t24);       set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch);
+        struct ggml_tensor * t26 = ggml_mul_mat (ctx, layer.ffn_gate, t24);     set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch);
         struct ggml_tensor * t27 = ggml_silu    (ctx, t26);                     set_name(t27, "t27"); assert_shape_2d(t27, n_ff, N*n_batch);
         struct ggml_tensor * t28 = ggml_mul     (ctx, t27, t25);                set_name(t28, "t28"); assert_shape_2d(t28, n_ff, N*n_batch);
-        struct ggml_tensor * t29 = ggml_mul_mat (ctx, layer.w2, t28);           set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
+        struct ggml_tensor * t29 = ggml_mul_mat (ctx, layer.ffn_down, t28);     set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
         struct ggml_tensor * t30 = ggml_add     (ctx, t29, t21);                set_name(t30, "t30"); assert_shape_2d(t30, n_embd, N*n_batch);
         cur = t30;
         checkpoints.push_back(cur);
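With the new names, the t25..t30 chain reads directly as the SwiGLU block plus the residual connection (a paraphrase of the graph above, no behavior change):

    t25 = ffn_up   · t24        // up projection,   n_embd -> n_ff
    t26 = ffn_gate · t24        // gate projection, n_embd -> n_ff
    t27 = SiLU(t26)             // gate activation
    t28 = t27 ⊙ t25             // element-wise gating
    t29 = ffn_down · t28        // down projection, n_ff -> n_embd
    t30 = t29 + t21             // residual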
@@ -521,9 +521,9 @@ static void load_llama_model_gguf(struct gguf_context * fctx, struct ggml_contex
         copy_tensor_by_name(layer.wv,       f_ggml_ctx, tni(LLM_TENSOR_ATTN_V,   i));
         copy_tensor_by_name(layer.wo,       f_ggml_ctx, tni(LLM_TENSOR_ATTN_OUT, i));
         copy_tensor_by_name(layer.ffn_norm, f_ggml_ctx, tni(LLM_TENSOR_FFN_NORM, i));
-        copy_tensor_by_name(layer.w1,       f_ggml_ctx, tni(LLM_TENSOR_FFN_GATE, i));
-        copy_tensor_by_name(layer.w2,       f_ggml_ctx, tni(LLM_TENSOR_FFN_DOWN, i));
-        copy_tensor_by_name(layer.w3,       f_ggml_ctx, tni(LLM_TENSOR_FFN_UP,   i));
+        copy_tensor_by_name(layer.ffn_gate, f_ggml_ctx, tni(LLM_TENSOR_FFN_GATE, i));
+        copy_tensor_by_name(layer.ffn_down, f_ggml_ctx, tni(LLM_TENSOR_FFN_DOWN, i));
+        copy_tensor_by_name(layer.ffn_up,   f_ggml_ctx, tni(LLM_TENSOR_FFN_UP,   i));
     }
 }
 
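Note that the serialized tensor names were already produced through tni(LLM_TENSOR_FFN_GATE, i) and friends before this change, as the removed lines show, so the rename only touches the C++ field names; GGUF checkpoints written by earlier builds should load unchanged.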
@@ -664,9 +664,9 @@ static void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vo
         gguf_add_tensor(fctx, layer.wv);
         gguf_add_tensor(fctx, layer.wo);
         gguf_add_tensor(fctx, layer.ffn_norm);
-        gguf_add_tensor(fctx, layer.w1);
-        gguf_add_tensor(fctx, layer.w2);
-        gguf_add_tensor(fctx, layer.w3);
+        gguf_add_tensor(fctx, layer.ffn_gate);
+        gguf_add_tensor(fctx, layer.ffn_down);
+        gguf_add_tensor(fctx, layer.ffn_up);
     }
 }
 
@@ -915,9 +915,9 @@ static int64_t get_parameter_count(struct my_llama_model* model) {
         nx += ggml_nelements(layer.wv);
         nx += ggml_nelements(layer.wo);
         nx += ggml_nelements(layer.ffn_norm);
-        nx += ggml_nelements(layer.w1);
-        nx += ggml_nelements(layer.w2);
-        nx += ggml_nelements(layer.w3);
+        nx += ggml_nelements(layer.ffn_gate);
+        nx += ggml_nelements(layer.ffn_down);
+        nx += ggml_nelements(layer.ffn_up);
     }
     return nx;
 }
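As a quick sanity check on the counting: each layer's feed-forward block contributes 3 * n_embd * n_ff weights from the three projections, plus n_embd for ffn_norm. With LLaMA-7B-like sizes (n_embd = 4096, n_ff = 11008, hypothetical here), that is 3 * 4096 * 11008 = 135,266,304 FFN weights per layer.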