
Commit e3c52bd

ggml : pass eps to ggml_norm
1 parent d561b7f commit e3c52bd

4 files changed: +49 -40 lines changed


ggml-metal.m

Lines changed: 2 additions & 1 deletion
@@ -938,7 +938,8 @@ void ggml_metal_graph_compute(
                        } break;
                    case GGML_OP_NORM:
                        {
-                           const float eps = 1e-5f;
+                           float eps;
+                           memcpy(&eps, dst->op_params, sizeof(float));
 
                            const int nth = 256;
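
Both the Metal path above and the CPU path in ggml.c now read the epsilon back out of dst->op_params with memcpy instead of using a hard-coded 1e-5f. A minimal, self-contained sketch of that round trip, assuming op_params is the tensor's small int32_t parameter area (as it is in ggml at this point; the demo below is illustrative, not library code):

    /* sketch: store a float in an int32_t parameter area and read it back
       with memcpy, which avoids strict-aliasing / type-punning issues */
    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    int main(void) {
        int32_t op_params[8] = {0};                 /* stand-in for ggml_tensor::op_params */

        const float eps_in = 1e-5f;
        memcpy(op_params, &eps_in, sizeof(eps_in)); /* roughly what ggml_set_op_params(result, &eps, sizeof(eps)) does */

        float eps = 0.0f;
        memcpy(&eps, op_params, sizeof(float));     /* what the NORM kernels now do at compute time */

        printf("eps = %g\n", eps);
        return 0;
    }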

ggml.c

Lines changed: 10 additions & 6 deletions
@@ -5789,6 +5789,7 @@ struct ggml_tensor * ggml_silu_back(
 static struct ggml_tensor * ggml_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
+        float eps,
         bool inplace) {
     bool is_node = false;
 
@@ -5799,7 +5800,7 @@ static struct ggml_tensor * ggml_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    // TODO: maybe store epsilon here?
+    ggml_set_op_params(result, &eps, sizeof(eps));
 
     result->op   = GGML_OP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
 
@@ -5810,14 +5811,16 @@ static struct ggml_tensor * ggml_norm_impl(
 
 struct ggml_tensor * ggml_norm(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
-    return ggml_norm_impl(ctx, a, false);
+        struct ggml_tensor * a,
+        float eps) {
+    return ggml_norm_impl(ctx, a, eps, false);
 }
 
 struct ggml_tensor * ggml_norm_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
-    return ggml_norm_impl(ctx, a, true);
+        struct ggml_tensor * a,
+        float eps) {
+    return ggml_norm_impl(ctx, a, eps, true);
 }
 
 // ggml_rms_norm
 
@@ -10619,7 +10622,8 @@ static void ggml_compute_forward_norm_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS;
 
-    const float eps = 1e-5f; // TODO: make this a parameter
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
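
For context on where eps ends up: the forward NORM op normalizes each row to zero mean and unit variance, and eps keeps the divisor away from zero on (near-)constant rows. A standalone sketch of that per-row math with the now-configurable epsilon (illustrative code with a made-up helper name, not the actual ggml kernel):

    #include <math.h>
    #include <stdio.h>

    /* subtract the row mean, then scale by 1/sqrt(variance + eps) */
    static void norm_row(float * x, int n, float eps) {
        double mean = 0.0;
        for (int i = 0; i < n; i++) {
            mean += x[i];
        }
        mean /= n;

        double sum2 = 0.0;
        for (int i = 0; i < n; i++) {
            const double v = x[i] - mean;
            x[i]  = (float) v;
            sum2 += v*v;
        }

        const float scale = 1.0f/sqrtf((float)(sum2/n) + eps);  /* eps guards the sqrt/division */
        for (int i = 0; i < n; i++) {
            x[i] *= scale;
        }
    }

    int main(void) {
        float row[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
        norm_row(row, 4, 1e-5f);   /* eps is supplied by the caller after this commit */
        for (int i = 0; i < 4; i++) {
            printf("%f\n", row[i]);
        }
        return 0;
    }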

ggml.h

Lines changed: 4 additions & 3 deletions
@@ -909,14 +909,15 @@ extern "C" {
             struct ggml_tensor  * b);
 
     // normalize along rows
-    // TODO: eps is hardcoded to 1e-5 for now
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);
 
     GGML_API struct ggml_tensor * ggml_norm_inplace(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);
 
     GGML_API struct ggml_tensor * ggml_rms_norm(
             struct ggml_context * ctx,
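
This is an API-breaking change for users of the public header: every ggml_norm / ggml_norm_inplace call site now has to supply an epsilon. A hedged sketch of an updated call site; the function name and the 1e-5f value are illustrative, not part of the commit:

    #include "ggml.h"

    /* hypothetical caller adapting to the new signature */
    struct ggml_tensor * normalize_rows(struct ggml_context * ctx, struct ggml_tensor * x) {
        const float eps = 1e-5f;            /* previously hard-coded inside ggml */
        return ggml_norm(ctx, x, eps);      /* ggml_norm_inplace takes the same extra argument */
    }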

llama.cpp

Lines changed: 33 additions & 30 deletions
@@ -830,6 +830,7 @@ struct llama_hparams {
     uint32_t n_rot   = 64;
     uint32_t n_ff    = 11008;
 
+    float f_norm_eps     = 1e-5;
     float f_norm_rms_eps = 1e-5;
 
     float rope_freq_base  = 10000.0f;
@@ -1557,6 +1558,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_FALCON:
             {
+                GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
             } break;
         default: (void)0;
     };
@@ -1672,28 +1674,29 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     const auto & vocab   = model.vocab;
 
     // hparams
-    LLAMA_LOG_INFO("%s: format       = %s\n",     __func__, llama_file_version_name(ml.fver));
-    LLAMA_LOG_INFO("%s: arch         = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
-    LLAMA_LOG_INFO("%s: vocab type   = %s\n",     __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
-    LLAMA_LOG_INFO("%s: n_vocab      = %u\n",     __func__, hparams.n_vocab);
-    LLAMA_LOG_INFO("%s: n_ctx_train  = %u\n",     __func__, hparams.n_ctx_train);
-    LLAMA_LOG_INFO("%s: n_ctx        = %u\n",     __func__, hparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_embd       = %u\n",     __func__, hparams.n_embd);
-    LLAMA_LOG_INFO("%s: n_head       = %u\n",     __func__, hparams.n_head);
-    LLAMA_LOG_INFO("%s: n_head_kv    = %u\n",     __func__, hparams.n_head_kv);
-    LLAMA_LOG_INFO("%s: n_layer      = %u\n",     __func__, hparams.n_layer);
-    LLAMA_LOG_INFO("%s: n_rot        = %u\n",     __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
-    LLAMA_LOG_INFO("%s: n_gqa        = %u\n",     __func__, hparams.n_gqa());
-    LLAMA_LOG_INFO("%s: f_norm_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
-    LLAMA_LOG_INFO("%s: n_ff         = %u\n",     __func__, hparams.n_ff);
-    LLAMA_LOG_INFO("%s: freq_base    = %.1f\n",   __func__, hparams.rope_freq_base);
-    LLAMA_LOG_INFO("%s: freq_scale   = %g\n",     __func__, hparams.rope_freq_scale);
-    LLAMA_LOG_INFO("%s: model type   = %s\n",     __func__, llama_model_type_name(model.type));
-    LLAMA_LOG_INFO("%s: model ftype  = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
-    LLAMA_LOG_INFO("%s: model size   = %.2f B\n", __func__, ml.n_elements*1e-9);
+    LLAMA_LOG_INFO("%s: format         = %s\n",     __func__, llama_file_version_name(ml.fver));
+    LLAMA_LOG_INFO("%s: arch           = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
+    LLAMA_LOG_INFO("%s: vocab type     = %s\n",     __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
+    LLAMA_LOG_INFO("%s: n_vocab        = %u\n",     __func__, hparams.n_vocab);
+    LLAMA_LOG_INFO("%s: n_ctx_train    = %u\n",     __func__, hparams.n_ctx_train);
+    LLAMA_LOG_INFO("%s: n_ctx          = %u\n",     __func__, hparams.n_ctx);
+    LLAMA_LOG_INFO("%s: n_embd         = %u\n",     __func__, hparams.n_embd);
+    LLAMA_LOG_INFO("%s: n_head         = %u\n",     __func__, hparams.n_head);
+    LLAMA_LOG_INFO("%s: n_head_kv      = %u\n",     __func__, hparams.n_head_kv);
+    LLAMA_LOG_INFO("%s: n_layer        = %u\n",     __func__, hparams.n_layer);
+    LLAMA_LOG_INFO("%s: n_rot          = %u\n",     __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
+    LLAMA_LOG_INFO("%s: n_gqa          = %u\n",     __func__, hparams.n_gqa());
+    LLAMA_LOG_INFO("%s: f_norm_eps     = %.1e\n",   __func__, hparams.f_norm_eps);
+    LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n",   __func__, hparams.f_norm_rms_eps);
+    LLAMA_LOG_INFO("%s: n_ff           = %u\n",     __func__, hparams.n_ff);
+    LLAMA_LOG_INFO("%s: freq_base      = %.1f\n",   __func__, hparams.rope_freq_base);
+    LLAMA_LOG_INFO("%s: freq_scale     = %g\n",     __func__, hparams.rope_freq_scale);
+    LLAMA_LOG_INFO("%s: model type     = %s\n",     __func__, llama_model_type_name(model.type));
+    LLAMA_LOG_INFO("%s: model ftype    = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
+    LLAMA_LOG_INFO("%s: model size     = %.2f B\n", __func__, ml.n_elements*1e-9);
 
     // general kv
-    LLAMA_LOG_INFO("%s: general.name = %s\n",    __func__, model.name.c_str());
+    LLAMA_LOG_INFO("%s: general.name   = %s\n",    __func__, model.name.c_str());
 
     // special tokens
     if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); }
@@ -1899,8 +1902,7 @@ static void llm_load_tensors(
             mmapped_size - vram_weights; // weights in VRAM not in memory
 
         // this is the memory required by one llama_state
-        const size_t mem_required_state =
-            scale*hparams.kv_size();
+        const size_t mem_required_state = scale*hparams.kv_size();
 
         LLAMA_LOG_INFO("%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
                 mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
@@ -2383,6 +2385,10 @@ static struct ggml_cgraph * llm_build_falcon(
 
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 
+    const float freq_base  = hparams.rope_freq_base;
+    const float freq_scale = hparams.rope_freq_scale;
+    const float norm_eps   = hparams.f_norm_eps;
+
     auto & buf_compute = lctx.buf_compute;
 
     struct ggml_init_params params = {
@@ -2436,7 +2442,7 @@ static struct ggml_cgraph * llm_build_falcon(
 
         // self-attention
        {
-            attn_norm = ggml_norm(ctx0, inpL);
+            attn_norm = ggml_norm(ctx0, inpL, norm_eps);
 
            attn_norm = ggml_add(ctx0,
                    ggml_mul(ctx0,
@@ -2445,7 +2451,7 @@ static struct ggml_cgraph * llm_build_falcon(
                    ggml_repeat(ctx0, model.layers[il].attn_norm_b, attn_norm));
 
            if (model.layers[il].attn_norm_2) { // Falcon-40B
-                cur = ggml_norm(ctx0, inpL);
+                cur = ggml_norm(ctx0, inpL, norm_eps);
 
                cur = ggml_add(ctx0,
                        ggml_mul(ctx0,
@@ -2490,8 +2496,8 @@ static struct ggml_cgraph * llm_build_falcon(
                    wsize * n_embd_head * (n_head + n_head_kv));
 
            // using mode = 2 for neox mode
-            Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0);
-            Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_embd_head, 2, 0);
+            Qcur = ggml_rope_custom_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
+            Kcur = ggml_rope_custom_inplace(ctx0, Kcur, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
 
            // store key and value to memory
            {
@@ -2522,8 +2528,6 @@ static struct ggml_cgraph * llm_build_falcon(
 
            // K * Q
 
-            // K = ggml_cont(ctx0, ggml_repeat2(ctx0, K, repeat_dummy));
-
            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
@@ -2549,7 +2553,6 @@ static struct ggml_cgraph * llm_build_falcon(
                        n_embd_head, n_head_kv, n_past + N),
                    0, 2, 1, 3);
 
-            // V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat2(ctx0, V, repeat_dummy)));
            V = ggml_cont(ctx0, ggml_transpose(ctx0, V));
 
            // KQV = transpose(V) * KQ_soft_max
@@ -2589,7 +2592,7 @@ static struct ggml_cgraph * llm_build_falcon(
 
    // norm
    {
-        cur = ggml_norm(ctx0, inpL);
+        cur = ggml_norm(ctx0, inpL, norm_eps);
 
        cur = ggml_add(ctx0,
                ggml_mul(ctx0,
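
On the llama.cpp side, Falcon models now read the layer-norm epsilon from GGUF metadata (the LLM_KV_ATTENTION_LAYERNORM_EPS key) into hparams.f_norm_eps and pass it to ggml_norm as norm_eps. A rough sketch of the kind of lookup the GGUF_GET_KEY macro performs, using the gguf_* API; the key string "falcon.attention.layer_norm_epsilon", the helper name, and the 1e-5f fallback are assumptions for illustration, not taken from the commit:

    #include "ggml.h"   /* the gguf_* API is declared here at this point in the tree */

    /* hypothetical standalone reader: open a GGUF file, look up the layer-norm
       epsilon, and fall back to the previous hard-coded default if it is missing */
    static float load_falcon_norm_eps(const char * fname) {
        struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
        struct gguf_context * ctx = gguf_init_from_file(fname, params);
        if (ctx == NULL) {
            return 1e-5f;
        }

        float eps = 1e-5f;
        const int kid = gguf_find_key(ctx, "falcon.attention.layer_norm_epsilon");
        if (kid >= 0) {
            eps = gguf_get_val_f32(ctx, kid);
        }

        gguf_free(ctx);
        return eps;
    }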
