Commit b4ee830

llama : rename missed batch params/vars to ubatch
This commit renames the `batch` parameter to `ubatch` in the `llama_kv_cache_find_slot`, `llm_build_inp_embd`, and `llm_build_mamba` functions, and likewise renames the local `llama_ubatch batch` variable in the state-loading code of `llama_data_read`. The motivation is that this rename should have been done as part of Commit 19d900a ("llama : rename batch to ubatch (ggml-org#9950)"), but I missed these functions in that commit and only noticed them now (sorry).
1 parent 2f0ee84 commit b4ee830
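
For readers skimming the hunks below: every renamed parameter and variable is of type `llama_ubatch` (a micro-batch split out of a larger `llama_batch` by `llama_sbatch`, as seen in the `ctx->sbatch.reserve_ubatch(...)` call further down), and the members used are visible in the diff itself (`n_tokens`, `n_seq_tokens`, `n_seqs`, `equal_seqs`, `token`, `pos`, `n_seq_id`, `seq_id`). The following is a minimal, hypothetical C++ sketch based only on those usages, not the actual definition in src/llama.cpp, to illustrate what such a micro-batch carries and why `ubatch` is the clearer parameter name:

    // Hypothetical sketch for illustration only -- NOT the actual definition in
    // src/llama.cpp. It lists just the llama_ubatch members that the hunks below
    // read, to show what a micro-batch ("ubatch") carries.
    #include <cstdint>

    using llama_pos    = int32_t;
    using llama_seq_id = int32_t;
    using llama_token  = int32_t;

    struct llama_ubatch_sketch {
        uint32_t        n_tokens;     // total number of tokens in this micro-batch
        uint32_t        n_seq_tokens; // tokens per sequence (when equal_seqs is true)
        uint32_t        n_seqs;       // number of sequences in the micro-batch
        bool            equal_seqs;   // all sequences contribute the same number of tokens
        llama_token   * token;        // token ids, or nullptr when embeddings are passed instead
        llama_pos     * pos;          // position of each token
        int32_t       * n_seq_id;     // per-sequence count of sequence ids
        llama_seq_id ** seq_id;       // per-sequence array of sequence ids
    };

    // After this commit, helpers taking a micro-batch consistently name the
    // parameter "ubatch" rather than "batch" (hypothetical helper, naming only):
    inline uint32_t ubatch_token_count(const llama_ubatch_sketch & ubatch) {
        return ubatch.equal_seqs ? ubatch.n_seq_tokens * ubatch.n_seqs : ubatch.n_tokens;
    }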

1 file changed: +35 −35 lines

src/llama.cpp

Lines changed: 35 additions & 35 deletions
@@ -3785,27 +3785,27 @@ static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
 // to the first cell of the slot.
 static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
        struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
+       const struct llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs = ubatch.n_seqs;
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should be always be contiguous.
 
         // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);
 
         int32_t min = cache.size - 1;
         int32_t max = 0;
 
         // everything should fit if all seq_ids are smaller than the max
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
+            const uint32_t n_seq_id = ubatch.n_seq_id[s];
             for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
 
                 if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                     // too big seq_id
@@ -3864,7 +3864,7 @@ static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
         // find usable cell range
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
             llama_kv_cell & seq_meta = cache.cells[seq_id];
             bool has_cell = false;
             if (seq_meta.tail >= 0) {
@@ -3903,7 +3903,7 @@ static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         // gather and re-order
         for (uint32_t s = 0; s < n_seqs; ++s) {
             int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
             if (dst_id != src_id) {
                 llama_kv_cell & dst_cell = cache.cells[dst_id];
                 llama_kv_cell & src_cell = cache.cells[src_id];
@@ -3924,20 +3924,20 @@ static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
         // update the pos of the used seqs
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
             int32_t cell_id = s + min;
             llama_kv_cell & cell = cache.cells[cell_id];
 
             if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
                 LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
             }
             cell.pos = last_pos;
             cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                 cell.seq_id.insert(seq_id);
                 cache.cells[seq_id].tail = cell_id;
             }
@@ -3991,10 +3991,10 @@ static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
+            cache.cells[cache.head + k].pos = ubatch.pos[k];
 
-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
             }
         }
     }
@@ -9931,21 +9931,21 @@ static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
         struct llama_context & lctx,
         const llama_hparams & hparams,
-        const llama_ubatch & batch,
+        const llama_ubatch & ubatch,
         struct ggml_tensor * tok_embd,
         const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
 
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+    if (ubatch.token) {
+        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
         cb(lctx.inp_tokens, "inp_tokens", -1);
         ggml_set_input(lctx.inp_tokens);
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
     } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
@@ -10518,7 +10518,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
 static struct ggml_tensor * llm_build_mamba(
         struct ggml_context * ctx,
         struct llama_context & lctx,
-        const llama_ubatch & batch,
+        const llama_ubatch & ubatch,
         struct ggml_cgraph * graph,
         struct ggml_tensor * cur,
         struct ggml_tensor * state_copy,
@@ -10534,17 +10534,17 @@ static struct ggml_tensor * llm_build_mamba(
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
     const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs = batch.n_seqs;
+    const int64_t n_seqs = ubatch.n_seqs;
     // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
     const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
     // Use the same RMS norm as the final layer norm
     const float norm_rms_eps = hparams.f_norm_rms_eps;
 
-    const int64_t n_seq_tokens = batch.n_seq_tokens;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
     GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(batch.equal_seqs);
-    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     struct ggml_tensor * conv_states_all = kv.k_l[il];
     struct ggml_tensor * ssm_states_all = kv.v_l[il];
@@ -21828,10 +21828,10 @@ struct llama_data_read {
 
         llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
 
-        llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
-        batch.n_tokens = cell_count;
-        batch.n_seq_tokens = cell_count;
-        batch.n_seqs = 1;
+        llama_ubatch ubatch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        ubatch.n_tokens = cell_count;
+        ubatch.n_seq_tokens = cell_count;
+        ubatch.n_seqs = 1;
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -21845,20 +21845,20 @@ struct llama_data_read {
                 return false;
             }
 
-            batch.pos[i] = pos;
+            ubatch.pos[i] = pos;
         }
-        batch.n_seq_id[0] = 1;
-        batch.seq_id[0] = &dest_seq_id;
-        if (!llama_kv_cache_find_slot(kv_self, batch)) {
+        ubatch.n_seq_id[0] = 1;
+        ubatch.seq_id[0] = &dest_seq_id;
+        if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
 
         // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
         GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
-        GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
-        GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+        GGML_ASSERT(kv_self.cells[kv_self.head].pos == ubatch.pos[0]);
+        GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
         GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
         GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
     } else {
