
Commit ebf093a (parent b7ee380)

opt : fix n_outputs

ggml-ci
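
In llama_context::opt_epoch_iter, this commit sets n_outputs to ubatch.n_tokens before the compute graph is built, reserves a KV cache slot for the ubatch (with a TODO noting the error path is not yet handled), and hands ggml_opt_prepare_alloc the logits tensor from the graph result instead of the last graph node.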


1 file changed: src/llama-context.cpp (12 additions, 1 deletion)
@@ -1954,6 +1954,17 @@ void llama_context::opt_epoch_iter(
         //}
         llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
 
+        n_outputs = ubatch.n_tokens;
+
+        printf("ubatch.n_tokens = %d\n", ubatch.n_tokens);
+
+        // TODO: not sure if this is needed
+        if (!kv_self->find_slot(ubatch)) {
+            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+            GGML_ABORT("TODO: handle this error");
+        }
+
         auto * gf = graph_init();
         auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
 
@@ -1969,7 +1980,7 @@ void llama_context::opt_epoch_iter(
             };
             ctx_compute_opt = ggml_init(params);
         }
-        ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), ggml_graph_node(gf, -1));
+        ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
         ggml_opt_alloc(opt_ctx, train);
         //llama_set_inputs(*lctx, ubatch);
         res->set_inputs(&ubatch);
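
The ggml_opt_prepare_alloc change is the core of the fix: the tensor passed as "outputs" is the one ggml-opt builds its loss on. A minimal before/after sketch, using only the calls that appear in the hunks above (all surrounding code elided):

    // Before: the optimizer's outputs tensor was whatever node happened to be
    // last in the compute graph, which need not be the logits.
    ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf,
                           res->get_tokens(),        // inputs: token ids of the ubatch
                           ggml_graph_node(gf, -1)); // outputs: last graph node (fragile)

    // After: the logits tensor is passed explicitly, so the loss is computed
    // against the per-token logits, consistent with n_outputs = ubatch.n_tokens.
    ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf,
                           res->get_tokens(),        // inputs: token ids of the ubatch
                           res->get_logits());       // outputs: logits tensor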

0 commit comments