
batch : rework llama_batch_allocr #14153


Merged: 5 commits, Jun 13, 2025
68 changes: 66 additions & 2 deletions src/llama-batch.cpp
@@ -1,5 +1,9 @@
#include "llama-batch.h"

#include "llama-impl.h"
#include "llama-cparams.h"
#include "llama-vocab.h"

#include <cassert>
#include <cstring>
#include <algorithm>
@@ -279,9 +283,42 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
);
}

llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
batch = in_batch;
llama_batch_allocr::llama_batch_allocr() = default;

bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
clear();

batch = batch_inp;

GGML_ASSERT(batch.n_tokens > 0);

if (!batch.pos) {
if (batch.seq_id) {
LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
return false;
}
}

if (batch.token) {
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return false;
}
}
}

if (batch.seq_id) {
for (int32_t i = 0; i < batch.n_tokens; ++i) {
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
return false;
}
}
}
}

if (!batch.pos) {
assert(p0 >= 0);
pos.resize(batch.n_tokens);
@@ -290,13 +327,15 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
}
batch.pos = pos.data();
}

if (!batch.n_seq_id) {
n_seq_id.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
n_seq_id[i] = seq_id_0.size();
}
batch.n_seq_id = n_seq_id.data();
}

if (!batch.seq_id) {
seq_id.resize(batch.n_tokens + 1);
seq_id[batch.n_tokens] = NULL;
@@ -305,12 +344,37 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
}
batch.seq_id = seq_id.data();
}

if (!batch.logits) {
// by default return the output only for the last token
output.resize(batch.n_tokens);
output[output.size() - 1] = true;
batch.logits = output.data();
}

for (int32_t i = 0; i < batch.n_tokens; ++i) {
n_outputs += batch.logits[i] != 0;
}

return true;
}

const llama_batch & llama_batch_allocr::get_batch() const {
return batch;
}

uint32_t llama_batch_allocr::get_n_outputs() const {
return n_outputs;
}

void llama_batch_allocr::clear() {
n_outputs = 0;

batch = {};
pos.clear();
n_seq_id.clear();
seq_id.clear();
output.clear();
}

//
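To illustrate the new validation path in init() (this sketch is not part of the diff; the surrounding setup and variable names are assumptions), an out-of-range token id now makes initialization fail early with a log message instead of propagating bad data into the graph:

```cpp
// Hypothetical caller-side sketch; `vocab` is assumed to be a valid llama_vocab reference.
std::vector<llama_token> tokens = { 1, 2, 3 };
tokens[0] = (llama_token) vocab.n_tokens();   // deliberately outside the vocabulary

llama_batch_allocr balloc;

// init() logs "invalid token[0] = ..." and returns false rather than asserting later
const bool ok = balloc.init(llama_batch_get_one(tokens.data(), (int32_t) tokens.size()), vocab, /*p0=*/0);
GGML_ASSERT(!ok);
```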
27 changes: 20 additions & 7 deletions src/llama-batch.h
@@ -18,8 +18,8 @@ struct llama_ubatch {
llama_token * token; // [n_tokens]
float * embd; // [n_embd, n_tokens]
llama_pos * pos; // [n_tokens]
int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id;
Comment on lines -21 to -22 (Member, Author):

Decided against these TODOs because multiple sequences per input token actually has some useful properties that cannot be achieved otherwise (for example, see the hellaswag usage). Instead, will add logic to guarantee that the provided ids are valid, utilizing the memory's seq_pos_min() and seq_pos_max() methods.
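For reference, a minimal sketch of the pattern the comment alludes to (the helper and the batch layout are assumptions, not part of this PR): a single prompt token can carry several sequence ids, so a shared prefix is evaluated once, as in the hellaswag example:

```cpp
// Assumes a batch created with llama_batch_init(n_tokens_max, /*embd=*/0, /*n_seq_max=*/n_endings),
// so each seq_id[i] array has room for n_endings entries.
static void add_shared_prompt_token(llama_batch & batch, llama_token tok, llama_pos pos, int32_t n_endings) {
    const int32_t i = batch.n_tokens;

    batch.token[i]    = tok;
    batch.pos[i]      = pos;
    batch.n_seq_id[i] = n_endings;          // one input token ...
    for (int32_t s = 0; s < n_endings; ++s) {
        batch.seq_id[i][s] = s;             // ... attributed to multiple sequences
    }
    batch.logits[i] = 0;                    // no output needed for prompt tokens

    batch.n_tokens++;
}
```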

int32_t * n_seq_id; // [n_seqs]
llama_seq_id ** seq_id; // [n_seqs]
int8_t * output; // [n_tokens]
};

@@ -78,15 +78,28 @@ struct llama_sbatch {
};

// temporary allocate memory for the input batch if needed
struct llama_batch_allocr {
struct llama_batch batch;
class llama_batch_allocr {
public:
llama_batch_allocr();

// optionally fulfill the batch returned by llama_batch_get_one
bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0);

const llama_batch & get_batch() const;

uint32_t get_n_outputs() const;

private:
void clear();

llama_batch batch;

uint32_t n_outputs;

std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id

std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> output;

// optionally fulfill the batch returned by llama_batch_get_one
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
};
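Putting the new interface together, here is a rough sketch of how a call site inside llama_context might drive the reworked allocator; the `vocab` reference, `batch_inp`, `n_past`, and the error handling are assumptions for illustration, not code from this PR:

```cpp
llama_batch_allocr batch_allocr;   // reusable; init() clears any previous state

// a token-only batch, e.g. from llama_batch_get_one(), has no pos/n_seq_id/seq_id/logits yet
if (!batch_allocr.init(batch_inp, vocab, /*p0=*/n_past)) {
    LLAMA_LOG_ERROR("%s: invalid input batch\n", __func__);
    return -1;
}

const llama_batch & batch = batch_allocr.get_batch();       // missing fields are now filled in
const uint32_t n_outputs  = batch_allocr.get_n_outputs();   // number of tokens with output requested
```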