@@ -1954,6 +1954,17 @@ void llama_context::opt_epoch_iter(
         // }
         llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);

+        n_outputs = ubatch.n_tokens;
+
+        printf("ubatch.n_tokens = %d\n", ubatch.n_tokens);
+
+        // TODO: not sure if this is needed
+        if (!kv_self->find_slot(ubatch)) {
+            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+            GGML_ABORT("TODO: handle this error");
+        }
+
         auto * gf = graph_init();
         auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);

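In words: during training every token in the ubatch produces an output, so n_outputs is set to ubatch.n_tokens, and before the graph is built the ubatch must be assigned cells in the unified KV cache via find_slot(), mirroring the decode path (the graph's KV store operations need those cells as destinations); the printf is a leftover debug trace. A minimal sketch of the per-ubatch sequence this hunk arrives at, assuming the enclosing ubatch loop of opt_epoch_iter:

```cpp
// Sketch (assumed context: inside opt_epoch_iter's loop over ubatches).
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);

// training computes a loss for every token, so every token is an output
n_outputs = ubatch.n_tokens;

// reserve KV cache cells for the ubatch, as on the decode path
if (!kv_self->find_slot(ubatch)) {
    GGML_ABORT("TODO: handle this error");
}

auto * gf  = graph_init();
auto   res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
```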
@@ -1969,7 +1980,7 @@ void llama_context::opt_epoch_iter(
             };
             ctx_compute_opt = ggml_init(params);
         }
-        ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), ggml_graph_node(gf, -1));
+        ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
         ggml_opt_alloc(opt_ctx, train);
         // llama_set_inputs(*lctx, ubatch);
         res->set_inputs(&ubatch);
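The one-line change in this hunk fixes which tensor ggml-opt treats as the graph's output: ggml_graph_node(gf, -1) returns whatever node happens to be last in the graph, which is not necessarily the logits tensor, so the logits tensor from the graph result is now passed explicitly. For reference, a sketch of the per-step ggml-opt sequence this code drives; the wrapper function and its name are hypothetical, while the ggml_opt_* calls are the public ggml-opt API used in the diff:

```cpp
#include "ggml-opt.h"

// Hypothetical wrapper showing one optimizer step with ggml-opt.
static void opt_step(ggml_opt_context_t opt_ctx, ggml_opt_result_t result,
                     ggml_context * ctx_compute, ggml_cgraph * gf,
                     ggml_tensor * tokens, ggml_tensor * logits, bool train) {
    // register the graph and its input/output tensors so ggml-opt can
    // attach the loss and, when training, the backward pass
    ggml_opt_prepare_alloc(opt_ctx, ctx_compute, gf, tokens, logits);

    // allocate the forward (and, if train == true, backward) graph
    ggml_opt_alloc(opt_ctx, train);

    // ... the caller fills the input tensors (tokens, labels) here ...

    // run forward, and when training also backward + the optimizer update;
    // loss/accuracy statistics are accumulated into 'result'
    ggml_opt_eval(opt_ctx, result);
}
```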