@@ -2201,10 +2201,7 @@ static struct ggml_cgraph * llm_build_llama(
                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
            }

-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                        Qcur,
-                        0, 2, 1, 3);
+            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);

            offload_func_kq(Q);
            ggml_set_name(Q, "Q");
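Note on the llm_build_llama hunk above: it only collapses the four-line ggml_permute call onto one line, so the graph itself is unchanged. For readers following the graph code, here is a small self-contained sketch (not part of the patch; the sizes are purely illustrative) of what permute(0, 2, 1, 3) does to Qcur, turning [n_embd_head, n_head, N] into [n_embd_head, N, n_head] so that the later ggml_mul_mat(K, Q) sees one n_embd_head x N matrix per head:

    // Standalone sketch, independent of llama.cpp; only shows ggml_permute dimension handling.
    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t n_embd_head = 64; // head dimension (illustrative)
        const int64_t n_head      = 8;  // number of query heads (illustrative)
        const int64_t N           = 4;  // tokens in the current batch (illustrative)

        // Qcur as built by the graph: ne = [n_embd_head, n_head, N]
        struct ggml_tensor * Qcur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, N);

        // permute(0, 2, 1, 3) swaps the head and token dimensions:
        // ne becomes [n_embd_head, N, n_head], i.e. one n_embd_head x N matrix per head.
        struct ggml_tensor * Q = ggml_permute(ctx, Qcur, 0, 2, 1, 3);

        printf("Qcur: %lld x %lld x %lld\n", (long long) Qcur->ne[0], (long long) Qcur->ne[1], (long long) Qcur->ne[2]);
        printf("Q:    %lld x %lld x %lld\n", (long long) Q->ne[0],    (long long) Q->ne[1],    (long long) Q->ne[2]);

        ggml_free(ctx);
        return 0;
    }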
@@ -2381,7 +2378,7 @@ static struct ggml_cgraph * llm_build_falcon(
    const int64_t n_head      = hparams.n_head;
    const int64_t n_head_kv   = hparams.n_head_kv;
    const int64_t n_embd_head = hparams.n_embd_head();
-   // const int64_t n_embd_gqa = hparams.n_embd_gqa();
+   const int64_t n_embd_gqa  = hparams.n_embd_gqa();

    GGML_ASSERT(n_embd_head == hparams.n_rot);
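Note on the hunk above: re-enabling n_embd_gqa is what the rest of the Falcon changes build on, since the new KV cache views below use rows that are n_embd_gqa elements wide. Assuming hparams.n_embd_gqa() is defined the same way as in the LLaMA path (n_embd divided by n_head/n_head_kv), it equals n_embd_head * n_head_kv, i.e. the width of one token's K (or V) data under grouped-query attention. A quick sanity-check sketch with illustrative numbers roughly matching Falcon-40B:

    // Illustrative check, not part of the patch; numbers are only an example.
    #include <assert.h>
    #include <stdio.h>

    int main(void) {
        const int n_embd      = 8192;
        const int n_head      = 128;
        const int n_head_kv   = 8;               // grouped-query attention: 8 KV heads
        const int n_embd_head = n_embd / n_head; // 64

        const int n_gqa      = n_head / n_head_kv; // 16 query heads share one KV head
        const int n_embd_gqa = n_embd / n_gqa;     // 512

        // n_embd_gqa is the total K (or V) row width for one token
        assert(n_embd_gqa == n_embd_head * n_head_kv);

        printf("n_embd_gqa = %d (= %d KV heads * %d per head)\n", n_embd_gqa, n_head_kv, n_embd_head);
        return 0;
    }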
@@ -2441,6 +2438,7 @@ static struct ggml_cgraph * llm_build_falcon(
        struct ggml_tensor * attn_norm;

        // self-attention
+       // TODO: refactor into common function (shared with LLaMA)
        {
            attn_norm = ggml_norm(ctx0, inpL, norm_eps);

@@ -2473,115 +2471,104 @@ static struct ggml_cgraph * llm_build_falcon(

            const size_t wsize = ggml_type_size(cur->type);

-            struct ggml_tensor * Qcur = ggml_view_3d(
+            struct ggml_tensor * tmpq = ggml_view_3d(
                ctx0, cur, n_embd_head, n_head, N,
                wsize * n_embd_head,
                wsize * n_embd_head * (n_head + 2 * n_head_kv),
                0);

-            struct ggml_tensor * Kcur = ggml_view_3d(
+            struct ggml_tensor * tmpk = ggml_view_3d(
                ctx0, cur, n_embd_head, n_head_kv, N,
                wsize * n_embd_head,
                wsize * n_embd_head * (n_head + 2 * n_head_kv),
-                wsize * n_embd_head * n_head);
+                wsize * n_embd_head * n_head);

-            struct ggml_tensor * Vcur = ggml_view_3d(
+            struct ggml_tensor * tmpv = ggml_view_3d(
                ctx0, cur, n_embd_head, n_head_kv, N,
                wsize * n_embd_head,
                wsize * n_embd_head * (n_head + 2 * n_head_kv),
                wsize * n_embd_head * (n_head + n_head_kv));

            // using mode = 2 for neox mode
-            Qcur = ggml_rope_custom_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
-            Kcur = ggml_rope_custom_inplace(ctx0, Kcur, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
+            struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
+            struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale);

-            // store key and value to memory
            {
-                struct ggml_tensor * k = ggml_view_1d(
-                    ctx0, kv_self.k, N * n_head_kv * n_embd_head,
-                    (ggml_element_size(kv_self.k) * n_head_kv * n_embd_head) *
-                        (il * n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_1d(
-                    ctx0, kv_self.v, N * n_head_kv * n_embd_head,
-                    (ggml_element_size(kv_self.v) * n_head_kv * n_embd_head) *
-                        (il * n_ctx + n_past));
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N));
+                ggml_set_name(Vcur, "Vcur");
+
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
+                ggml_set_name(k, "k");
+
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
+                        (   n_ctx)*ggml_element_size(kv_self.v),
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));

                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
            }

-            struct ggml_tensor * K = ggml_permute(
-                ctx0,
-                ggml_reshape_3d(
-                    ctx0,
-                    ggml_view_1d(ctx0, kv_self.k, (n_past + N) * n_head_kv * n_embd_head,
-                                 il * n_ctx *
-                                     ggml_element_size(kv_self.k) *
-                                     n_head_kv *
-                                     n_embd_head),
-                    n_embd_head, n_head_kv, n_past + N),
-                0, 2, 1, 3);
+            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+            ggml_set_name(Q, "Q");

-            // K * Q
+            struct ggml_tensor * K =
+                ggml_view_3d(ctx0, kv_self.k,
+                        n_embd_head, n_past + N, n_head_kv,
+                        ggml_element_size(kv_self.k)*n_embd_gqa,
+                        ggml_element_size(kv_self.k)*n_embd_head,
+                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+            ggml_set_name(K, "K");

-            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            ggml_set_name(KQ, "KQ");

-            // KQ_scaled = KQ / sqrt(n_embd/n_head)
            struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+            ggml_set_name(KQ_scaled, "KQ_scaled");

-            // KQ_masked = mask_past(KQ_scaled)
            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+            ggml_set_name(KQ_masked, "KQ_masked");

-            // KQ = soft_max(KQ_masked)
            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            ggml_set_name(KQ_soft_max, "KQ_soft_max");
+
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, kv_self.v,
+                        n_past + N, n_embd_head, n_head_kv,
+                        ggml_element_size(kv_self.v)*n_ctx,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+            ggml_set_name(V, "V");

-            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
-            struct ggml_tensor * V = ggml_permute(
-                ctx0,
-                ggml_reshape_3d(
-                    ctx0,
-                    ggml_view_1d(ctx0, kv_self.v, (n_past + N) * n_head_kv * n_embd_head,
-                                 il * n_ctx *
-                                     ggml_element_size(kv_self.v) *
-                                     n_head_kv *
-                                     n_embd_head),
-                    n_embd_head, n_head_kv, n_past + N),
-                0, 2, 1, 3);
-
-            V = ggml_cont(ctx0, ggml_transpose(ctx0, V));
-
-            // KQV = transpose(V) * KQ_soft_max
            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            ggml_set_name(KQV, "KQV");

-            // KQV_merged = KQV.permute(0, 2, 1, 3)
            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            ggml_set_name(KQV_merged, "KQV_merged");

-            // cur = KQV_merged.contiguous().view(n_embd, N)
            cur = ggml_cpy(ctx0,
                    KQV_merged,
                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+            ggml_set_name(cur, "KQV_merged_contiguous");

-            // projection
-            {
-                cur = ggml_mul_mat(ctx0,
-                        model.layers[il].wo,
-                        cur);
-            }
+            cur = ggml_cpy(ctx0,
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+            ggml_set_name(cur, "result_wo");
        }

        struct ggml_tensor * inpFF = attn_norm;
        struct ggml_tensor * attn_out = ggml_cpy(
            ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));

-        {
-            cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF);
-            cur = ggml_gelu(ctx0, cur);
-            cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
-        }
+        cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF);
+        cur = ggml_gelu(ctx0, cur);
+        cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);

        cur = ggml_add(ctx0, cur, attn_out);
        cur = ggml_add(ctx0, cur, inpL);
+
        // input for next layer
        inpL = cur;
    }
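Note on the large hunk above: it moves Falcon onto the KV-cache layout already used by the LLaMA graph. K keeps one n_embd_gqa-wide row per token, and V is now stored transposed at write time (the new Vcur reshape/transpose), so at read time both K and V become plain ggml_view_3d views into the cache, with no per-step permute or ggml_cont. A rough sketch of the byte strides behind those two views, assuming an f16 cache (2-byte elements) and purely illustrative sizes; the names mirror the patch but the program is standalone:

    // Sketch, not part of the patch: strides/offsets used by the new K and V views.
    #include <stdio.h>
    #include <stddef.h>

    int main(void) {
        const size_t esize       = 2;    // ggml_element_size(kv_self.k) for an f16 cache (assumed)
        const size_t n_ctx       = 2048;
        const size_t n_embd_head = 64;
        const size_t n_head_kv   = 8;
        const size_t n_embd_gqa  = n_embd_head * n_head_kv; // 512
        const size_t il          = 3;    // example layer index
        const size_t n_past      = 10;
        const size_t N           = 4;

        // K cache: one row of n_embd_gqa values per token, n_ctx tokens per layer.
        // View: ne = [n_embd_head, n_past + N, n_head_kv]
        size_t k_nb1    = esize * n_embd_gqa;              // next token, same head
        size_t k_nb2    = esize * n_embd_head;             // next KV head, same token
        size_t k_offset = esize * n_embd_gqa * n_ctx * il; // start of this layer's K block

        // V cache: stored transposed, one row of n_ctx values per channel.
        // View: ne = [n_past + N, n_embd_head, n_head_kv]
        size_t v_nb1    = esize * n_ctx;                   // next channel within a head
        size_t v_nb2    = esize * n_ctx * n_embd_head;     // next KV head
        size_t v_offset = esize * n_ctx * n_embd_gqa * il; // start of this layer's V block

        printf("K view: ne = [%zu, %zu, %zu], nb1 = %zu, nb2 = %zu, offset = %zu bytes\n",
               n_embd_head, n_past + N, n_head_kv, k_nb1, k_nb2, k_offset);
        printf("V view: ne = [%zu, %zu, %zu], nb1 = %zu, nb2 = %zu, offset = %zu bytes\n",
               n_past + N, n_embd_head, n_head_kv, v_nb1, v_nb2, v_offset);
        return 0;
    }

One possible leftover in the hunk: the added cur = ggml_cpy(ctx0, KQV_merged, ...) right after ggml_set_name(cur, "KQV_merged_contiguous") appears to repeat the copy already performed on the preceding context lines, so it looks redundant, though harmless.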