Commit ed7bb58

context : avoid passing unique_ptr
ggml-ci
Parent: 17809cf
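
The commit replaces getters that returned const references to smart pointers with getters that return the underlying raw handles, so callers no longer see (or depend on) the ownership type. A minimal sketch of the pattern, with hypothetical names rather than code from this commit:

#include <memory>

struct widget { int value = 0; };

class holder {
public:
    holder() : w(std::make_unique<widget>()) {}

    // before: leaks the ownership type into the interface
    // const std::unique_ptr<widget> & get_widget() const { return w; }

    // after: callers receive a non-owning observer pointer
    widget * get_widget() const { return w.get(); }

private:
    std::unique_ptr<widget> w; // holder keeps sole ownership
};

int main() {
    holder h;
    widget * w = h.get_widget(); // no .get() at the call site
    w->value = 42;
    return 0;
}

Ownership stays inside the class either way; the raw pointer is just an observer that is valid for as long as the owning object lives.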

File tree: 3 files changed, +18 -16 lines


src/llama-context.cpp: 4 additions, 4 deletions

@@ -400,12 +400,12 @@ const llama_cparams & llama_context::get_cparams() const {
     return cparams;
 }

-const ggml_backend_sched_ptr & llama_context::get_sched() const {
-    return sched;
+ggml_backend_sched_t llama_context::get_sched() const {
+    return sched.get();
 }

-const ggml_context_ptr & llama_context::get_ctx_compute() const {
-    return ctx_compute;
+ggml_context * llama_context::get_ctx_compute() const {
+    return ctx_compute.get();
 }

 const std::vector<ggml_backend_ptr> & llama_context::get_backends() const {

src/llama-context.h: 3 additions, 2 deletions

@@ -30,10 +30,11 @@ struct llama_context {
     const llama_model & get_model() const;
     const llama_cparams & get_cparams() const;

-    const ggml_backend_sched_ptr & get_sched() const;
+    ggml_backend_sched_t get_sched() const;

-    const ggml_context_ptr & get_ctx_compute() const;
+    ggml_context * get_ctx_compute() const;

+    // TODO: this method might be possible to avoid (search for TAG_BACKENDS)
     const std::vector<ggml_backend_ptr> & get_backends() const;

     uint32_t n_ctx() const;
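
Note that the two new return types are the same kind of handle: in ggml's public header, ggml_backend_sched_t is a typedef for a pointer to the opaque scheduler struct, so get_sched() now returns a non-owning pointer just like get_ctx_compute():

// From ggml-backend.h: the scheduler handle is an opaque pointer type,
// so returning it by value is equivalent to returning a raw pointer.
typedef struct ggml_backend_sched * ggml_backend_sched_t;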

src/llama-kv-cache.cpp: 11 additions, 10 deletions

@@ -371,7 +371,7 @@ void llama_kv_cache_unified::commit() {
 bool llama_kv_cache_unified::update(llama_context & lctx) {
     bool need_reserve = false;

-    const auto & sched = lctx.get_sched();
+    auto * sched = lctx.get_sched();

     if (has_shift) {
         if (!get_can_shift()) {
@@ -382,13 +382,13 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {

         // apply K-shift if needed
         if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            ggml_backend_sched_reset(sched.get());
+            ggml_backend_sched_reset(sched);

             auto * gf = lctx.graph_init();

             auto res = build_graph_shift(lctx, gf);

-            ggml_backend_sched_alloc_graph(sched.get(), gf);
+            ggml_backend_sched_alloc_graph(sched, gf);

             res->set_inputs(nullptr);

@@ -410,13 +410,13 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

         if (defrag_prepare(lctx.graph_max_nodes())) {
-            ggml_backend_sched_reset(sched.get());
+            ggml_backend_sched_reset(sched);

             auto * gf = lctx.graph_init();

             auto res = build_graph_defrag(lctx, gf);

-            ggml_backend_sched_alloc_graph(sched.get(), gf);
+            ggml_backend_sched_alloc_graph(sched, gf);

             res->set_inputs(nullptr);

@@ -602,7 +602,8 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
         ggml_backend_buffer * bbuf) const {
     const auto & cparams = lctx.get_cparams();
     const auto & backends = lctx.get_backends();
-    const auto & sched = lctx.get_sched();
+
+    auto * sched = lctx.get_sched();

     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;

@@ -623,12 +624,12 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
         // dequantize to f32 -> RoPE -> quantize back
         tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);

-        // TODO: can we simplify/avoid this?
+        // TODO: can we simplify/avoid this? [TAG_BACKENDS]
         if (bbuf) {
             for (const auto & backend : backends) {
                 // Figure out which backend KV cache belongs to
                 if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
-                    ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
+                    ggml_backend_sched_set_tensor_backend(sched, tmp, backend.get());
                     break;
                 }
             }
@@ -680,7 +681,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
         ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();

-    auto * ctx = lctx.get_ctx_compute().get();
+    auto * ctx = lctx.get_ctx_compute();

     const auto & cparams = lctx.get_cparams();

@@ -733,7 +734,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
         ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();

-    auto * ctx = lctx.get_ctx_compute().get();
+    auto * ctx = lctx.get_ctx_compute();

     const auto & ids = defrag_info.ids;
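
At the call sites the unique_ptr plumbing disappears: the kv-cache code receives raw handles and passes them straight to the ggml scheduler API, which takes ggml_backend_sched_t. A condensed, illustrative sketch of the resulting call-site shape (not a complete function from the commit):

// Illustrative only: condensed from the update() changes above.
// The cache merely observes the scheduler and compute context;
// llama_context still owns both through its unique_ptr members.
void example_kv_update(llama_context & lctx) {
    auto * sched = lctx.get_sched();       // ggml_backend_sched_t, no .get()
    auto * ctx   = lctx.get_ctx_compute(); // ggml_context *, no .get()

    ggml_backend_sched_reset(sched);       // ggml C API takes the raw handle
    (void) ctx;
}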
