Skip to content

Commit c678901

Browse files
committed
fix: Give ownership of child caches to the hybrid cache
The parent should fully own the lifecycle of the children which is managed by the m_children member holding unique_ptrs. These need to be initialized correctly, so the constructor now takes the input vector of child_cache by value instead of reference so that the child pointers can be transferred to the parent cache. The expectation is that the vector of child_cache instances will be instantiated in-place with move semantics. Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
1 parent a99cbd3 commit c678901

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

src/llama-kv-cache.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2427,23 +2427,23 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
24272427
// llama_kv_cache_hybrid
24282428
//
24292429
llama_kv_cache_hybrid::llama_kv_cache_hybrid(
2430-
const llama_hparams & hparams,
2431-
const std::vector<child_cache> & children) :
2430+
const llama_hparams & hparams,
2431+
std::vector<child_cache> children) :
24322432
m_hparams(hparams),
24332433
m_layer_cache_map(
24342434
[](const std::vector<child_cache>& caches) -> std::unordered_map<size_t, llama_kv_cache*> {
24352435
std::unordered_map<size_t, llama_kv_cache*> map;
24362436
for (const auto & cache : caches) {
24372437
for (size_t layer_id : cache.layer_ids) {
2438-
map[layer_id] = cache.child;
2438+
map[layer_id] = cache.child.get();
24392439
}
24402440
}
24412441

24422442
return map;
24432443
}(children)
24442444
),
24452445
m_children(
2446-
[](std::vector<child_cache> caches) -> std::set<llama_kv_cache*> {
2446+
[](std::vector<child_cache>& caches) -> std::set<std::unique_ptr<llama_kv_cache>> {
24472447
// Sort the caches by the lowest layer ID so the order is repeatable
24482448
for (auto & cache : caches) {
24492449
GGML_ASSERT(cache.layer_ids.size() > 0);
@@ -2452,22 +2452,22 @@ llama_kv_cache_hybrid::llama_kv_cache_hybrid(
24522452
std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
24532453
return a.layer_ids[0] < b.layer_ids[0];
24542454
});
2455-
std::set<llama_kv_cache*> unique_caches;
2456-
for (const auto & cache : caches) {
2457-
unique_caches.insert(cache.child);
2455+
std::set<std::unique_ptr<llama_kv_cache>> unique_caches;
2456+
for (auto & cache : caches) {
2457+
unique_caches.emplace(cache.child.release());
24582458
}
24592459
return unique_caches;
24602460
}(children)
24612461
),
24622462
m_has_recurrent(
2463-
[](const std::vector<child_cache>& caches) -> bool {
2463+
[](const std::set<std::unique_ptr<llama_kv_cache>> & caches) -> bool {
24642464
for (const auto & cache : caches) {
2465-
if (dynamic_cast<llama_kv_cache_recurrent *>(cache.child)) {
2465+
if (dynamic_cast<llama_kv_cache_recurrent *>(cache.get())) {
24662466
return true;
24672467
}
24682468
}
24692469
return false;
2470-
}(children)
2470+
}(m_children)
24712471
)
24722472
{
24732473
// Ensure at least one child

src/llama-kv-cache.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -407,13 +407,16 @@ class llama_kv_cache_hybrid : public llama_kv_cache {
407407
public:
408408

409409
struct child_cache {
410-
llama_kv_cache * child;
411-
std::vector<size_t> layer_ids;
410+
std::unique_ptr<llama_kv_cache> child;
411+
std::vector<size_t> layer_ids;
412+
413+
child_cache(std::unique_ptr<llama_kv_cache> child_, std::vector<size_t> layer_ids_)
414+
: child(std::move(child_)), layer_ids(std::move(layer_ids_)) {}
412415
};
413416

414417
llama_kv_cache_hybrid(
415418
const llama_hparams & hparams,
416-
const std::vector<child_cache> & children);
419+
std::vector<child_cache> children);
417420

418421
//
419422
// llama_memory_i
@@ -470,7 +473,7 @@ class llama_kv_cache_hybrid : public llama_kv_cache {
470473

471474
const llama_hparams & m_hparams;
472475
const std::unordered_map<size_t, llama_kv_cache *> m_layer_cache_map;
473-
const std::set<llama_kv_cache *> m_children; // Ordered for state IO
476+
const std::set<std::unique_ptr<llama_kv_cache>> m_children; // Ordered for state IO
474477
const bool m_has_recurrent;
475478
};
476479

0 commit comments

Comments
 (0)