 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16

+#define LLAMA_USE_FLASH_ATTN
+
 #define LLAMA_ASSERT(x) \
     do { \
         if (!(x)) { \

@@ -829,6 +831,30 @@ static bool llama_eval_internal(
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
             }

+#ifdef LLAMA_USE_FLASH_ATTN
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
+                            Qcur,
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_embd/n_head, n_head, N)),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, kv_self.v,
+                        n_past + N, n_embd/n_head, n_head,
+                        n_ctx*ggml_element_size(kv_self.v),
+                        n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
+                        il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
+
+            struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
+#else
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
                         Qcur,
@@ -872,6 +898,7 @@ static bool llama_eval_internal(
             // is there a better way?
             struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
+#endif
 #endif

             // KQV_merged = KQV.permute(0, 2, 1, 3)
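
For context on what the new branch replaces: ggml_flash_attn(ctx0, Q, K, V, true) computes masked attention as a single fused graph op (the trailing true enables the causal mask). The non-flash path of the same file builds the equivalent result from separate graph nodes, roughly as sketched below; the tensor names follow the surrounding code, and this is an illustration of that path rather than the literal lines elided by the hunks above:

    // attention scores for every head: KQ = K * Q
    struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

    // scale by 1/sqrt(head_dim)
    struct ggml_tensor * KQ_scaled =
        ggml_scale(ctx0, KQ, ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));

    // causal mask (hide positions after the current token), then softmax
    struct ggml_tensor * KQ_masked   = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
    struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);

    // weighted sum of values, using the contiguous copy of V from the #else branch
    struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);

The fused call avoids materializing the full (n_past + N) x N score matrix per head, which is where the memory saving at long context comes from.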
|
|