Skip to content

Commit 33309f6

Browse files
llama : check all graph nodes when searching for result_embd_pooled (#8956)
Co-authored-by: Stanisław Szymczyk <[email protected]>
1 parent 7c5bfd5 commit 33309f6

File tree

1 file changed: +8 -5 lines changed

src/llama.cpp

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -14722,12 +14722,15 @@ static int llama_decode_internal(
             res = nullptr;
             embd = nullptr;
         } else if (cparams.embeddings) {
-            res = nullptr; // do not extract logits for embedding case
-            embd = gf->nodes[gf->n_nodes - 1];
-            if (strcmp(embd->name, "result_embd_pooled") != 0) {
-                embd = gf->nodes[gf->n_nodes - 2];
+            res = nullptr; // do not extract logits for embedding case
+            embd = nullptr;
+            for (int i = gf->n_nodes - 1; i >= 0; --i) {
+                if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
+                    embd = gf->nodes[i];
+                    break;
+                }
             }
-            GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
+            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");

0 commit comments

Comments
 (0)