Commit 36ddd12

llama : add flash attention (demo)
1 parent 986b6ce commit 36ddd12

File tree

1 file changed (+27 -0 lines)


llama.cpp

Lines changed: 27 additions & 0 deletions
@@ -28,6 +28,8 @@
 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16

+#define LLAMA_USE_FLASH_ATTN
+
 #define LLAMA_ASSERT(x) \
     do { \
         if (!(x)) { \
@@ -829,6 +831,30 @@ static bool llama_eval_internal(
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
             }

+#ifdef LLAMA_USE_FLASH_ATTN
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
+                            Qcur,
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_embd/n_head, n_head, N)),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, kv_self.v,
+                        n_past + N, n_embd/n_head, n_head,
+                        n_ctx*ggml_element_size(kv_self.v),
+                        n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
+                        il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
+
+            struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
+#else
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
                         Qcur,
@@ -872,6 +898,7 @@ static bool llama_eval_internal(
             // is there a better way?
             struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
+#endif
 #endif

             // KQV_merged = KQV.permute(0, 2, 1, 3)
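
For reference, the single fused call in the new path, ggml_flash_attn(ctx0, Q, K, V, true), computes per head the same masked attention that the #else branch assembles from separate ggml ops: the K·Q matmul, scaling by 1/sqrt(n_embd/n_head), the causal mask, the softmax, and the final multiply by V (the last of those steps, the multiply by KQ_soft_max, is visible in the final hunk); the trailing boolean enables the mask. The sketch below is a plain scalar reference for a single head under that reading. It is not ggml code and does not reproduce the cache's memory layout; the function and parameter names are illustrative only.

// Minimal single-head reference for out = softmax(mask(Q·K^T / sqrt(d))) · V.
// Illustrative sketch only; not the ggml_flash_attn implementation.
#include <algorithm>
#include <cmath>
#include <vector>

// q: N x d (new tokens), k/v: M x d with M = n_past + N (cached + new), out: N x d
static void attention_ref(const std::vector<float> & q,
                          const std::vector<float> & k,
                          const std::vector<float> & v,
                          std::vector<float> & out,
                          int N, int M, int d, int n_past) {
    const float scale = 1.0f/std::sqrt((float) d);
    std::vector<float> s(M);

    for (int i = 0; i < N; ++i) {
        // scaled scores of query i against all keys, with causal masking:
        // the query at absolute position n_past + i only attends to positions <= itself
        float smax = -INFINITY;
        for (int j = 0; j < M; ++j) {
            float dot = 0.0f;
            for (int c = 0; c < d; ++c) {
                dot += q[i*d + c]*k[j*d + c];
            }
            s[j] = (j <= n_past + i) ? dot*scale : -INFINITY;
            smax = std::max(smax, s[j]);
        }

        // softmax over the masked scores (masked entries contribute exp(-inf) = 0)
        float sum = 0.0f;
        for (int j = 0; j < M; ++j) {
            s[j] = std::exp(s[j] - smax);
            sum += s[j];
        }

        // weighted sum of value rows
        for (int c = 0; c < d; ++c) {
            float acc = 0.0f;
            for (int j = 0; j < M; ++j) {
                acc += (s[j]/sum)*v[j*d + c];
            }
            out[i*d + c] = acc;
        }
    }
}

Compared with the fallback path, the flash branch in the diff copies Qcur into an F16 tensor and feeds V as a strided 3-d view of the value cache directly, so it avoids the contiguous V_cont copy flagged by the "is there a better way?" comment.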
