
Commit 5a60db5

feat: Zero-out recurrent / non-recurrent layers in the single-type caches
This is a bit of an inversion of concerns, so we could conceivably make the interface to this more opaque to the other cache types by providing something like a layer mask. Since these cache implementations already have access to the hparams, though, it seems minimally invasive to just check the new recurrent_layer function.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 8a13b03 commit 5a60db5
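
For readers outside the branch, here is a minimal sketch of the kind of per-layer flag the commit message refers to. The struct name, field name, array bound, and helper signature below are assumptions for illustration only, not the actual llama.cpp definitions:

    // Hypothetical sketch only: names and layout are assumed, not copied from llama.cpp.
    #include <array>
    #include <cstdint>

    struct llama_hparams_sketch {
        static constexpr uint32_t MAX_LAYERS = 512;

        uint32_t n_layer = 0;
        // one flag per layer: true -> recurrent (e.g. SSM) layer, false -> attention layer
        std::array<bool, MAX_LAYERS> recurrent_layer_arr = {};

        // is layer il a recurrent layer?
        bool recurrent_layer(uint32_t il) const {
            return recurrent_layer_arr[il];
        }
    };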

File tree

1 file changed: +10 -4 lines changed

src/llama-kv-cache.cpp

Lines changed: 10 additions & 4 deletions
@@ -100,8 +100,11 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        // any recurrent layers in the model will not use this cache
+        const uint32_t tensor_dim = hparams.recurrent_layer(i) ? 0 : kv_size;
+
+        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*tensor_dim);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*tensor_dim);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l.push_back(k);
@@ -1447,8 +1450,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        // any non-recurrent layers in the model will not use this cache
+        const uint32_t tensor_dim = hparams.recurrent_layer(i) ? kv_size : 0;
+
+        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*tensor_dim);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*tensor_dim);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l.push_back(k);
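
The two hunks apply the same sizing rule, just mirrored: each cache still creates one K and one V tensor per layer, so per-layer indexing stays uniform, but layers owned by the other cache type get a zero-element tensor and so contribute effectively no KV memory. A standalone sketch of that rule, reusing the hypothetical hparams sketch above (the helper name is illustrative, not part of the patch):

    // for_recurrent_cache == true  -> sizing rule used by the recurrent cache
    // for_recurrent_cache == false -> sizing rule used by the unified (attention) cache
    uint32_t layer_kv_dim(const llama_hparams_sketch & hparams,
                          uint32_t il,
                          uint32_t kv_size,
                          bool     for_recurrent_cache) {
        const bool is_recurrent = hparams.recurrent_layer(il);
        // a layer sized 0 allocates no cells in this cache
        return (is_recurrent == for_recurrent_cache) ? kv_size : 0;
    }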
