Skip to content

Commit 8d7cfb4

Browse files
committed
fix baby llama, test-grad0
1 parent 24e53a1 commit 8d7cfb4

File tree

2 files changed

+12
-10
lines changed

examples/baby-llama/baby-llama.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#pragma warning(disable: 4244 4267) // possible loss of data
99
#endif
1010

11+
static const float rms_norm_eps = 1e-6f;
12+
1113
float frand() {
1214
return (float)rand()/(float)RAND_MAX;
1315
}
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
562564
// norm
563565
{
564566
// cur shape [n_embd,N,1,1]
565-
cur = ggml_rms_norm(ctx0, inpL);
567+
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
566568

567569
// cur = attention_norm*cur
568570
cur = ggml_mul(ctx0,
@@ -685,7 +687,7 @@ struct ggml_tensor * forward(
685687
// norm
686688
{
687689
// cur shape [n_embd,N,1,1]
688-
cur = ggml_rms_norm(ctx0, inpFF);
690+
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
689691

690692
// cur = ffn_norm*cur
691693
// cur shape [n_embd,N,1,1]
@@ -729,7 +731,7 @@ struct ggml_tensor * forward(
729731
{
730732

731733
// inpL shape [n_embd,N,1,1]
732-
inpL = ggml_rms_norm(ctx0, inpL);
734+
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
733735

734736
// inpL = norm*inpL
735737
// inpL shape [n_embd,N,1,1]
@@ -817,7 +819,7 @@ struct ggml_tensor * forward_batch(
817819
// norm
818820
{
819821
// cur shape [n_embd,N*n_batch,1,1]
820-
cur = ggml_rms_norm(ctx0, inpL);
822+
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
821823
assert_shape_2d(cur, n_embd, N*n_batch);
822824

823825
// cur = attention_norm*cur
@@ -981,7 +983,7 @@ struct ggml_tensor * forward_batch(
981983
// norm
982984
{
983985
// cur shape [n_embd,N*n_batch,1,1]
984-
cur = ggml_rms_norm(ctx0, inpFF);
986+
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
985987
assert_shape_2d(cur, n_embd, N*n_batch);
986988

987989
// cur = ffn_norm*cur
@@ -1034,7 +1036,7 @@ struct ggml_tensor * forward_batch(
10341036
{
10351037

10361038
// inpL shape [n_embd,N*n_batch,1,1]
1037-
inpL = ggml_rms_norm(ctx0, inpL);
1039+
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
10381040
assert_shape_2d(inpL, n_embd, N*n_batch);
10391041

10401042
// inpL = norm*inpL
@@ -1104,7 +1106,7 @@ struct ggml_tensor * forward_lora(
11041106
// norm
11051107
{
11061108
// cur shape [n_embd,N,1,1]
1107-
cur = ggml_rms_norm(ctx0, inpL);
1109+
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
11081110

11091111
// cur = attention_norm*cur
11101112
cur = ggml_mul(ctx0,
@@ -1251,7 +1253,7 @@ struct ggml_tensor * forward_lora(
12511253
// norm
12521254
{
12531255
// cur shape [n_embd,N,1,1]
1254-
cur = ggml_rms_norm(ctx0, inpFF);
1256+
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
12551257

12561258
// cur = ffn_norm*cur
12571259
// cur shape [n_embd,N,1,1]
@@ -1295,7 +1297,7 @@ struct ggml_tensor * forward_lora(
12951297
{
12961298

12971299
// inpL shape [n_embd,N,1,1]
1298-
inpL = ggml_rms_norm(ctx0, inpL);
1300+
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
12991301

13001302
// inpL = norm*inpL
13011303
// inpL shape [n_embd,N,1,1]

tests/test-grad0.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ int main(int argc, const char ** argv) {
850850
ggml_set_param(ctx0, x[i]);
851851
}
852852

853-
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0]));
853+
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
854854

855855
check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
856856
}

0 commit comments

Comments (0)