Commit b34ab74

falcon : copy-paste self-attention from LLaMA
1 parent af4bbcc commit b34ab74

File tree

1 file changed: +52 -65 lines


llama.cpp

Lines changed: 52 additions & 65 deletions
@@ -2201,10 +2201,7 @@ static struct ggml_cgraph * llm_build_llama(
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
         }
 
-        struct ggml_tensor * Q =
-            ggml_permute(ctx0,
-                    Qcur,
-                    0, 2, 1, 3);
+        struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
         offload_func_kq(Q);
         ggml_set_name(Q, "Q");
 
@@ -2381,7 +2378,7 @@ static struct ggml_cgraph * llm_build_falcon(
     const int64_t n_head      = hparams.n_head;
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
-    //const int64_t n_embd_gqa = hparams.n_embd_gqa();
+    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
 
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 
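Note: n_embd_gqa (re-enabled above) is the per-token width of the K/V projections. Assuming the usual llama.cpp hparams definitions, it reduces to n_embd_head * n_head_kv, i.e. it equals n_embd for plain multi-head attention and shrinks when there are fewer KV heads. A minimal sketch of that assumed relation (the helper name is illustrative, not part of the commit):

    #include <stdint.h>

    // Assumed relations (illustrative only):
    //   n_embd_head = n_embd / n_head
    //   n_embd_gqa  = n_embd_head * n_head_kv   (== n_embd when n_head_kv == n_head)
    static int64_t embd_gqa(int64_t n_embd, int64_t n_head, int64_t n_head_kv) {
        return (n_embd / n_head) * n_head_kv;
    }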

@@ -2441,6 +2438,7 @@ static struct ggml_cgraph * llm_build_falcon(
         struct ggml_tensor * attn_norm;
 
         // self-attention
+        // TODO: refactor into common function (shared with LLaMA)
         {
             attn_norm = ggml_norm(ctx0, inpL, norm_eps);
 
@@ -2473,115 +2471,104 @@ static struct ggml_cgraph * llm_build_falcon(
 
             const size_t wsize = ggml_type_size(cur->type);
 
-            struct ggml_tensor * Qcur = ggml_view_3d(
+            struct ggml_tensor * tmpq = ggml_view_3d(
                 ctx0, cur, n_embd_head, n_head, N,
                 wsize * n_embd_head,
                 wsize * n_embd_head * (n_head + 2 * n_head_kv),
                 0);
 
-            struct ggml_tensor * Kcur = ggml_view_3d(
+            struct ggml_tensor * tmpk = ggml_view_3d(
                 ctx0, cur, n_embd_head, n_head_kv, N,
                 wsize * n_embd_head,
                 wsize * n_embd_head * (n_head + 2 * n_head_kv),
-                wsize * n_embd_head * n_head);
+                wsize * n_embd_head * n_head);
 
-            struct ggml_tensor * Vcur = ggml_view_3d(
+            struct ggml_tensor * tmpv = ggml_view_3d(
                 ctx0, cur, n_embd_head, n_head_kv, N,
                 wsize * n_embd_head,
                 wsize * n_embd_head * (n_head + 2 * n_head_kv),
                 wsize * n_embd_head * (n_head + n_head_kv));
 
             // using mode = 2 for neox mode
-            Qcur = ggml_rope_custom_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
-            Kcur = ggml_rope_custom_inplace(ctx0, Kcur, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
+            struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
+            struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
 
-            // store key and value to memory
             {
-                struct ggml_tensor* k = ggml_view_1d(
-                    ctx0, kv_self.k, N * n_head_kv * n_embd_head,
-                    (ggml_element_size(kv_self.k) * n_head_kv * n_embd_head) *
-                        (il * n_ctx + n_past));
-                struct ggml_tensor* v = ggml_view_1d(
-                    ctx0, kv_self.v, N * n_head_kv * n_embd_head,
-                    (ggml_element_size(kv_self.v) * n_head_kv * n_embd_head) *
-                        (il * n_ctx + n_past));
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N));
+                ggml_set_name(Vcur, "Vcur");
+
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
+                ggml_set_name(k, "k");
+
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
+                        (   n_ctx)*ggml_element_size(kv_self.v),
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
 
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
 
-            struct ggml_tensor * K = ggml_permute(
-                ctx0,
-                ggml_reshape_3d(
-                    ctx0,
-                    ggml_view_1d(ctx0, kv_self.k, (n_past + N) * n_head_kv * n_embd_head,
-                                 il * n_ctx *
-                                     ggml_element_size(kv_self.k) *
-                                     n_head_kv *
-                                     n_embd_head),
-                    n_embd_head, n_head_kv, n_past + N),
-                0, 2, 1, 3);
+            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+            ggml_set_name(Q, "Q");
 
-            // K * Q
+            struct ggml_tensor * K =
+                ggml_view_3d(ctx0, kv_self.k,
+                        n_embd_head, n_past + N, n_head_kv,
+                        ggml_element_size(kv_self.k)*n_embd_gqa,
+                        ggml_element_size(kv_self.k)*n_embd_head,
+                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+            ggml_set_name(K, "K");
 
-            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            ggml_set_name(KQ, "KQ");
 
-            // KQ_scaled = KQ / sqrt(n_embd/n_head)
             struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+            ggml_set_name(KQ_scaled, "KQ_scaled");
 
-            // KQ_masked = mask_past(KQ_scaled)
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+            ggml_set_name(KQ_masked, "KQ_masked");
 
-            // KQ = soft_max(KQ_masked)
             struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            ggml_set_name(KQ_soft_max, "KQ_soft_max");
+
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, kv_self.v,
+                        n_past + N, n_embd_head, n_head_kv,
+                        ggml_element_size(kv_self.v)*n_ctx,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+            ggml_set_name(V, "V");
 
-            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
-            struct ggml_tensor* V = ggml_permute(
-                ctx0,
-                ggml_reshape_3d(
-                    ctx0,
-                    ggml_view_1d(ctx0, kv_self.v, (n_past + N) * n_head_kv * n_embd_head,
-                                 il * n_ctx *
-                                     ggml_element_size(kv_self.v) *
-                                     n_head_kv *
-                                     n_embd_head),
-                    n_embd_head, n_head_kv, n_past + N),
-                0, 2, 1, 3);
-
-            V = ggml_cont(ctx0, ggml_transpose(ctx0, V));
-
-            // KQV = transpose(V) * KQ_soft_max
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            ggml_set_name(KQV, "KQV");
 
-            // KQV_merged = KQV.permute(0, 2, 1, 3)
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            ggml_set_name(KQV_merged, "KQV_merged");
 
-            // cur = KQV_merged.contiguous().view(n_embd, N)
             cur = ggml_cpy(ctx0,
                 KQV_merged,
                 ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+            ggml_set_name(cur, "KQV_merged_contiguous");
 
-            // projection
-            {
-                cur = ggml_mul_mat(ctx0,
-                    model.layers[il].wo,
-                    cur);
-            }
+            cur = ggml_cpy(ctx0,
+                KQV_merged,
+                ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+            ggml_set_name(cur, "result_wo");
         }
 
         struct ggml_tensor * inpFF = attn_norm;
         struct ggml_tensor * attn_out = ggml_cpy(
             ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
 
-        {
-            cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF);
-            cur = ggml_gelu(ctx0, cur);
-            cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
-        }
+        cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF);
+        cur = ggml_gelu(ctx0, cur);
+        cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
 
         cur = ggml_add(ctx0, cur, attn_out);
         cur = ggml_add(ctx0, cur, inpL);
+
         // input for next layer
         inpL = cur;
     }
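Note on the new cache layout used by the views above (a sketch of my reading of the diff, not part of the commit): the tmpq/tmpk/tmpv views slice each token's fused wqkv row in Q, K, V order; the K cache holds n_embd_gqa contiguous values per token with n_ctx tokens per layer, while the V cache is stored transposed with an n_ctx stride per embedding dimension, which is why k is written through a 1-D view and v through a 2-D view. The helpers below only restate the byte offsets from the diff; the function names are hypothetical:

    #include <stddef.h>
    #include <stdint.h>

    // Byte offset of (layer il, position n_past) in the K cache; matches
    // ggml_element_size(kv_self.k)*n_embd_gqa*(il*n_ctx + n_past) from the diff.
    static size_t k_cache_offset(size_t esize, int64_t n_embd_gqa, int64_t n_ctx, int64_t il, int64_t n_past) {
        return (size_t)(esize*n_embd_gqa*(il*n_ctx + n_past));
    }

    // Byte offset of (layer il, position n_past) in the transposed V cache; matches
    // il*n_ctx*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v).
    static size_t v_cache_offset(size_t esize, int64_t n_embd_gqa, int64_t n_ctx, int64_t il, int64_t n_past) {
        return (size_t)(esize*n_embd_gqa*il*n_ctx + esize*n_past);
    }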
