Commit c76851e

llama : rename missed batch params/vars to ubatch
This commit renames the `batch` parameter to `ubatch` in the `llama_kv_cache_find_slot`, `llm_build_inp_embd`, and `llm_build_mamba` functions. This should have been done as part of commit 19d900a ("llama : rename batch to ubatch (#9950)"), but I missed these functions in that commit and only noticed them now (sorry).
1 parent cc2983d commit c76851e
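
For context, a sketch of the `llama_ubatch` fields that the renamed code references. This is inferred from the usage in the diff below, not copied from the actual declaration in src/llama.cpp, so types, field order, and any omitted members are assumptions:

    // Sketch of llama_ubatch, inferred from usage in this diff (not the real declaration).
    struct llama_ubatch {
        bool            equal_seqs;   // all sequences have the same number of new tokens
        uint32_t        n_tokens;     // total new tokens; n_seq_tokens * n_seqs when equal_seqs
        uint32_t        n_seq_tokens; // new tokens per sequence
        uint32_t        n_seqs;       // number of sequences in this ubatch

        llama_token  *  token;        // token ids, or null when embeddings are supplied instead
        llama_pos    *  pos;          // per-token positions, laid out sequence-major
        int32_t      *  n_seq_id;     // per-sequence: how many seq_ids that sequence has
        llama_seq_id ** seq_id;       // per-sequence: the seq_ids themselves
    };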

src/llama.cpp

Lines changed: 35 additions & 35 deletions
@@ -3591,27 +3591,27 @@ static bool llama_kv_cache_init(
 // to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
+       const struct llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs = ubatch.n_seqs;
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should be always be contiguous.
 
         // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);
 
         int32_t min = cache.size - 1;
         int32_t max = 0;
 
         // everything should fit if all seq_ids are smaller than the max
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
+            const uint32_t n_seq_id = ubatch.n_seq_id[s];
             for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
 
                 if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                     // too big seq_id
@@ -3670,7 +3670,7 @@ static bool llama_kv_cache_find_slot(
 
         // find usable cell range
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
             llama_kv_cell & seq_meta = cache.cells[seq_id];
             bool has_cell = false;
             if (seq_meta.tail >= 0) {
@@ -3709,7 +3709,7 @@ static bool llama_kv_cache_find_slot(
         // gather and re-order
         for (uint32_t s = 0; s < n_seqs; ++s) {
             int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
             if (dst_id != src_id) {
                 llama_kv_cell & dst_cell = cache.cells[dst_id];
                 llama_kv_cell & src_cell = cache.cells[src_id];
@@ -3730,20 +3730,20 @@ static bool llama_kv_cache_find_slot(
 
         // update the pos of the used seqs
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
             int32_t cell_id = s + min;
             llama_kv_cell & cell = cache.cells[cell_id];
 
             if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
                 LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
             }
             cell.pos = last_pos;
             cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                 cell.seq_id.insert(seq_id);
                 cache.cells[seq_id].tail = cell_id;
             }
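
A note on the pos indexing in the hunk above: with `equal_seqs`, `ubatch.pos` is laid out sequence-major, so the last new position of sequence `s` sits at index `n_seq_tokens * s + n_seq_tokens - 1`. A minimal self-contained illustration with made-up values:

    #include <cassert>
    #include <cstdint>

    int main() {
        // Hypothetical ubatch layout: 2 sequences, 3 new tokens each (sequence-major).
        const uint32_t n_seq_tokens = 3;
        const int32_t  pos[] = {10, 11, 12, 20, 21, 22}; // seq 0: 10..12, seq 1: 20..22

        for (uint32_t s = 0; s < 2; ++s) {
            // Same index expression as last_pos in the hunk above.
            const int32_t last_pos = pos[n_seq_tokens * s + n_seq_tokens - 1];
            assert(last_pos == (s == 0 ? 12 : 22));
        }
        return 0;
    }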
@@ -3795,10 +3795,10 @@ static bool llama_kv_cache_find_slot(
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
+            cache.cells[cache.head + k].pos = ubatch.pos[k];
 
-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
             }
         }
     }
@@ -9178,21 +9178,21 @@ static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
        struct llama_context & lctx,
          const llama_hparams & hparams,
-          const llama_ubatch & batch,
+          const llama_ubatch & ubatch,
         struct ggml_tensor * tok_embd,
          const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
 
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+    if (ubatch.token) {
+        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
         cb(lctx.inp_tokens, "inp_tokens", -1);
         ggml_set_input(lctx.inp_tokens);
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
     } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
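
The branch above covers the two input modes of `llm_build_inp_embd`: token ids are converted to embeddings by `ggml_get_rows` on the embedding table, while precomputed embeddings are passed through as `inp_embd`. A standalone sketch of the lookup half, using only ggml calls that appear in the hunk plus `ggml_init`/`ggml_free`; all sizes are made up for illustration:

    #include "ggml.h"

    int main() {
        // Hypothetical sizes, for illustration only.
        const int64_t n_embd   = 8;   // embedding width
        const int64_t n_vocab  = 16;  // rows in the embedding table
        const int64_t n_tokens = 4;   // would be ubatch.n_tokens

        struct ggml_init_params params = {
            /*.mem_size   =*/ 16u*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // Embedding table [n_embd, n_vocab] and I32 token ids, as in the hunk.
        struct ggml_tensor * tok_embd   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
        ggml_set_input(inp_tokens);

        // Picks one embedding row per token id -> shape [n_embd, n_tokens],
        // the same shape the else-branch allocates directly for inp_embd.
        struct ggml_tensor * inpL = ggml_get_rows(ctx, tok_embd, inp_tokens);
        (void) inpL;

        ggml_free(ctx);
        return 0;
    }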
@@ -9766,7 +9766,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
 static struct ggml_tensor * llm_build_mamba(
         struct ggml_context * ctx,
        struct llama_context & lctx,
-          const llama_ubatch & batch,
+          const llama_ubatch & ubatch,
         struct ggml_cgraph * graph,
          struct ggml_tensor * cur,
          struct ggml_tensor * state_copy,
@@ -9782,17 +9782,17 @@ static struct ggml_tensor * llm_build_mamba(
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
     const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs = batch.n_seqs;
+    const int64_t n_seqs = ubatch.n_seqs;
     // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
     const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
     // Use the same RMS norm as the final layer norm
     const float norm_rms_eps = hparams.f_norm_rms_eps;
 
-    const int64_t n_seq_tokens = batch.n_seq_tokens;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
     GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(batch.equal_seqs);
-    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     struct ggml_tensor * conv_states_all = kv.k_l[il];
     struct ggml_tensor * ssm_states_all = kv.v_l[il];
@@ -20440,10 +20440,10 @@ struct llama_data_read {
 
         llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
 
-        llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
-        batch.n_tokens = cell_count;
-        batch.n_seq_tokens = cell_count;
-        batch.n_seqs = 1;
+        llama_ubatch ubatch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        ubatch.n_tokens = cell_count;
+        ubatch.n_seq_tokens = cell_count;
+        ubatch.n_seqs = 1;
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -20457,20 +20457,20 @@ struct llama_data_read {
                 return false;
             }
 
-            batch.pos[i] = pos;
+            ubatch.pos[i] = pos;
         }
-        batch.n_seq_id[0] = 1;
-        batch.seq_id[0] = &dest_seq_id;
-        if (!llama_kv_cache_find_slot(kv_self, batch)) {
+        ubatch.n_seq_id[0] = 1;
+        ubatch.seq_id[0] = &dest_seq_id;
+        if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
 
         // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
         GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
-        GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
-        GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+        GGML_ASSERT(kv_self.cells[kv_self.head].pos == ubatch.pos[0]);
+        GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
         GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
         GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
     } else {
