Skip to content

Commit 6369f86

Browse files
authored
llama : rename missed batch params/vars to ubatch (#10059)
This commit renames the `batch` parameter to `ubatch` in the `llama_kv_cache_find_slot`, `llm_build_inp_embd`, and `llm_build_mamba` functions. The motivation for this is that this should have been done as part of Commit 19d900a ("llama : rename batch to ubatch (#9950)") but for some reason I missed these functions in that commit and only noticed them now (sorry).
1 parent 47182dd commit 6369f86

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

src/llama-kv-cache.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -119,27 +119,27 @@ bool llama_kv_cache_init(
119119

120120
struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
121121
struct llama_kv_cache & cache,
122-
const struct llama_ubatch & batch) {
123-
const uint32_t n_tokens = batch.n_tokens;
124-
const uint32_t n_seqs = batch.n_seqs;
125-
const uint32_t n_seq_tokens = batch.n_seq_tokens;
122+
const struct llama_ubatch & ubatch) {
123+
const uint32_t n_tokens = ubatch.n_tokens;
124+
const uint32_t n_seqs = ubatch.n_seqs;
125+
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
126126

127127
if (cache.recurrent) {
128128
// For recurrent state architectures (like Mamba or RWKV),
129129
// each cache cell can store the state for a whole sequence.
130130
// A slot should be always be contiguous.
131131

132132
// can only process batches with an equal number of new tokens in each sequence
133-
GGML_ASSERT(batch.equal_seqs);
133+
GGML_ASSERT(ubatch.equal_seqs);
134134

135135
int32_t min = cache.size - 1;
136136
int32_t max = 0;
137137

138138
// everything should fit if all seq_ids are smaller than the max
139139
for (uint32_t s = 0; s < n_seqs; ++s) {
140-
const uint32_t n_seq_id = batch.n_seq_id[s];
140+
const uint32_t n_seq_id = ubatch.n_seq_id[s];
141141
for (uint32_t j = 0; j < n_seq_id; ++j) {
142-
const llama_seq_id seq_id = batch.seq_id[s][j];
142+
const llama_seq_id seq_id = ubatch.seq_id[s][j];
143143

144144
if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
145145
// too big seq_id
@@ -198,7 +198,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
198198

199199
// find usable cell range
200200
for (uint32_t s = 0; s < n_seqs; ++s) {
201-
const llama_seq_id seq_id = batch.seq_id[s][0];
201+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
202202
llama_kv_cell & seq_meta = cache.cells[seq_id];
203203
bool has_cell = false;
204204
if (seq_meta.tail >= 0) {
@@ -237,7 +237,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
237237
// gather and re-order
238238
for (uint32_t s = 0; s < n_seqs; ++s) {
239239
int32_t dst_id = s + min;
240-
int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
240+
int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
241241
if (dst_id != src_id) {
242242
llama_kv_cell & dst_cell = cache.cells[dst_id];
243243
llama_kv_cell & src_cell = cache.cells[src_id];
@@ -258,20 +258,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
258258

259259
// update the pos of the used seqs
260260
for (uint32_t s = 0; s < n_seqs; ++s) {
261-
const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
261+
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
262262
int32_t cell_id = s + min;
263263
llama_kv_cell & cell = cache.cells[cell_id];
264264

265265
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
266266
// What should happen when the pos backtracks or skips a value?
267267
// Clearing the state mid-batch would require special-casing which isn't done.
268268
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
269-
__func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
269+
__func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
270270
}
271271
cell.pos = last_pos;
272272
cell.seq_id.clear();
273-
for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
274-
const llama_seq_id seq_id = batch.seq_id[s][j];
273+
for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
274+
const llama_seq_id seq_id = ubatch.seq_id[s][j];
275275
cell.seq_id.insert(seq_id);
276276
cache.cells[seq_id].tail = cell_id;
277277
}
@@ -325,10 +325,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
325325
for (uint32_t s = 0; s < n_seqs; s++) {
326326
for (uint32_t i = 0; i < n_seq_tokens; ++i) {
327327
uint32_t k = s*n_seq_tokens + i;
328-
cache.cells[cache.head + k].pos = batch.pos[k];
328+
cache.cells[cache.head + k].pos = ubatch.pos[k];
329329

330-
for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
331-
cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
330+
for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
331+
cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
332332
}
333333
}
334334
}

src/llama.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2540,21 +2540,21 @@ static struct ggml_tensor * llm_build_inp_embd(
25402540
struct ggml_context * ctx,
25412541
struct llama_context & lctx,
25422542
const llama_hparams & hparams,
2543-
const llama_ubatch & batch,
2543+
const llama_ubatch & ubatch,
25442544
struct ggml_tensor * tok_embd,
25452545
const llm_build_cb & cb) {
25462546
const int64_t n_embd = hparams.n_embd;
25472547

25482548
struct ggml_tensor * inpL;
25492549

2550-
if (batch.token) {
2551-
lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
2550+
if (ubatch.token) {
2551+
lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
25522552
cb(lctx.inp_tokens, "inp_tokens", -1);
25532553
ggml_set_input(lctx.inp_tokens);
25542554

25552555
inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
25562556
} else {
2557-
lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
2557+
lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
25582558
inpL = lctx.inp_embd;
25592559
ggml_set_input(lctx.inp_embd);
25602560
}
@@ -3149,7 +3149,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
31493149
static struct ggml_tensor * llm_build_mamba(
31503150
struct ggml_context * ctx,
31513151
struct llama_context & lctx,
3152-
const llama_ubatch & batch,
3152+
const llama_ubatch & ubatch,
31533153
struct ggml_cgraph * graph,
31543154
struct ggml_tensor * cur,
31553155
struct ggml_tensor * state_copy,
@@ -3165,17 +3165,17 @@ static struct ggml_tensor * llm_build_mamba(
31653165
const int64_t d_inner = hparams.ssm_d_inner;
31663166
const int64_t d_state = hparams.ssm_d_state;
31673167
const int64_t dt_rank = hparams.ssm_dt_rank;
3168-
const int64_t n_seqs = batch.n_seqs;
3168+
const int64_t n_seqs = ubatch.n_seqs;
31693169
// Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
31703170
const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
31713171
// Use the same RMS norm as the final layer norm
31723172
const float norm_rms_eps = hparams.f_norm_rms_eps;
31733173

3174-
const int64_t n_seq_tokens = batch.n_seq_tokens;
3174+
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
31753175

31763176
GGML_ASSERT(n_seqs != 0);
3177-
GGML_ASSERT(batch.equal_seqs);
3178-
GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
3177+
GGML_ASSERT(ubatch.equal_seqs);
3178+
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
31793179

31803180
struct ggml_tensor * conv_states_all = kv.k_l[il];
31813181
struct ggml_tensor * ssm_states_all = kv.v_l[il];

0 commit comments

Comments
 (0)