Commit 6bf6a0a

danbev authored and NeoZhangJianyu committed
llama : rename missed batch params/vars to ubatch (ggml-org#10059)
This commit renames the `batch` parameter to `ubatch` in the `llama_kv_cache_find_slot`, `llm_build_inp_embd`, and `llm_build_mamba` functions. The motivation for this is that this should have been done as part of Commit 19d900a ("llama : rename batch to ubatch (ggml-org#9950)") but for some reason I missed these functions in that commit and only noticed them now (sorry).
1 parent 84a2e3d commit 6bf6a0a
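
For context, a `llama_ubatch` is the micro-batch the decoder actually processes after the user-facing `llama_batch` has been split; the rename makes that distinction visible at these call sites. Below is a rough, non-authoritative sketch of the ubatch fields the patched functions read. Field names are taken from the diff in this commit; the real struct is internal to llama.cpp and may differ in order and contain additional members.

```cpp
// Sketch only: approximates the llama_ubatch fields referenced by this commit's diff.
#include <cstdint>

typedef int32_t llama_pos;      // public llama.h typedefs, repeated so the sketch stands alone
typedef int32_t llama_token;
typedef int32_t llama_seq_id;

struct llama_ubatch_sketch {
    bool            equal_seqs;   // every sequence contributes the same number of new tokens
    uint32_t        n_tokens;     // total tokens in this micro-batch
    uint32_t        n_seq_tokens; // new tokens per sequence
    uint32_t        n_seqs;       // number of sequences in the micro-batch

    llama_token  *  token;        // token ids, or nullptr when embeddings are supplied instead
    llama_pos    *  pos;          // position of each token
    int32_t      *  n_seq_id;     // per-sequence count of sequence ids
    llama_seq_id ** seq_id;       // per-sequence array of sequence ids
};
```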

File tree

src/llama-kv-cache.cpp
src/llama.cpp

2 files changed: +25 −25 lines

src/llama-kv-cache.cpp

Lines changed: 16 additions & 16 deletions

@@ -119,27 +119,27 @@ bool llama_kv_cache_init(
 
 struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
+       const struct llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs = ubatch.n_seqs;
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should be always be contiguous.
 
         // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);
 
         int32_t min = cache.size - 1;
         int32_t max = 0;
 
         // everything should fit if all seq_ids are smaller than the max
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
+            const uint32_t n_seq_id = ubatch.n_seq_id[s];
             for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
 
                 if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                     // too big seq_id
@@ -198,7 +198,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
         // find usable cell range
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
             llama_kv_cell & seq_meta = cache.cells[seq_id];
             bool has_cell = false;
             if (seq_meta.tail >= 0) {
@@ -237,7 +237,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         // gather and re-order
         for (uint32_t s = 0; s < n_seqs; ++s) {
             int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
             if (dst_id != src_id) {
                 llama_kv_cell & dst_cell = cache.cells[dst_id];
                 llama_kv_cell & src_cell = cache.cells[src_id];
@@ -258,20 +258,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
         // update the pos of the used seqs
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
             int32_t cell_id = s + min;
             llama_kv_cell & cell = cache.cells[cell_id];
 
             if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
                 LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
             }
             cell.pos = last_pos;
             cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                 cell.seq_id.insert(seq_id);
                 cache.cells[seq_id].tail = cell_id;
             }
@@ -325,10 +325,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
+            cache.cells[cache.head + k].pos = ubatch.pos[k];
 
-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
             }
         }
     }
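
The indexing in these hunks assumes the ubatch's flat `pos` array is laid out sequence-major: token `i` of sequence `s` lives at index `s * n_seq_tokens + i`, which is why the last position of a sequence is read at `n_seq_tokens * s + n_seq_tokens - 1`. A tiny standalone sketch of that layout, using hypothetical values that are not part of the patch:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

typedef int32_t llama_pos;

int main() {
    // Hypothetical ubatch shape: 2 sequences, 3 new tokens each (equal_seqs).
    const uint32_t n_seqs = 2, n_seq_tokens = 3;

    // Flat, sequence-major position array, as llama_kv_cache_find_slot indexes it.
    std::vector<llama_pos> pos = { 10, 11, 12,   // sequence 0
                                   40, 41, 42 }; // sequence 1

    for (uint32_t s = 0; s < n_seqs; ++s) {
        // Same index expression the find-slot code uses for the last position of sequence s.
        const llama_pos last_pos = pos[n_seq_tokens * s + n_seq_tokens - 1];
        printf("seq %u: last_pos = %d\n", s, last_pos);
    }
    return 0;
}
```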

src/llama.cpp

Lines changed: 9 additions & 9 deletions

@@ -2549,21 +2549,21 @@ static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
        struct llama_context & lctx,
         const llama_hparams & hparams,
-          const llama_ubatch & batch,
+          const llama_ubatch & ubatch,
          struct ggml_tensor * tok_embd,
          const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
 
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+    if (ubatch.token) {
+        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
         cb(lctx.inp_tokens, "inp_tokens", -1);
         ggml_set_input(lctx.inp_tokens);
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
     } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
@@ -3158,7 +3158,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
 static struct ggml_tensor * llm_build_mamba(
         struct ggml_context * ctx,
        struct llama_context & lctx,
-          const llama_ubatch & batch,
+          const llama_ubatch & ubatch,
          struct ggml_cgraph * graph,
          struct ggml_tensor * cur,
          struct ggml_tensor * state_copy,
@@ -3174,17 +3174,17 @@ static struct ggml_tensor * llm_build_mamba(
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
     const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs = batch.n_seqs;
+    const int64_t n_seqs = ubatch.n_seqs;
     // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
     const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
     // Use the same RMS norm as the final layer norm
     const float norm_rms_eps = hparams.f_norm_rms_eps;
 
-    const int64_t n_seq_tokens = batch.n_seq_tokens;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
     GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(batch.equal_seqs);
-    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     struct ggml_tensor * conv_states_all = kv.k_l[il];
     struct ggml_tensor * ssm_states_all = kv.v_l[il];
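
For readers unfamiliar with `llm_build_inp_embd`: when the micro-batch carries token ids (`ubatch.token != nullptr`), the graph input is an I32 tensor of `n_tokens` ids and the embeddings are gathered from `tok_embd` with `ggml_get_rows`; otherwise the caller supplies pre-computed embeddings as an F32 tensor of shape `n_embd x n_tokens`. Below is a plain-CPU sketch of the row-gather the token path performs, with hypothetical toy sizes and no ggml dependency; it only illustrates the indexing, not the real graph construction.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

typedef int32_t llama_token;

int main() {
    // Hypothetical toy sizes: vocabulary of 5 tokens, embedding width 4.
    const int64_t n_vocab = 5, n_embd = 4;

    // tok_embd laid out row-major: one n_embd-wide row per vocabulary entry.
    std::vector<float> tok_embd(n_vocab * n_embd);
    for (int64_t v = 0; v < n_vocab; ++v)
        for (int64_t e = 0; e < n_embd; ++e)
            tok_embd[v * n_embd + e] = float(v) + 0.1f * float(e);

    // A micro-batch of token ids; the result plays the role of the
    // ggml_get_rows(tok_embd, inp_tokens) output in llm_build_inp_embd.
    std::vector<llama_token> tokens = { 3, 0, 4 };
    std::vector<float> inpL(tokens.size() * n_embd);
    for (size_t t = 0; t < tokens.size(); ++t)
        for (int64_t e = 0; e < n_embd; ++e)
            inpL[t * n_embd + e] = tok_embd[tokens[t] * n_embd + e];

    printf("first embedding value for token %d = %.1f\n", tokens[0], inpL[0]);
    return 0;
}
```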
