@@ -3785,27 +3785,27 @@ static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
 // to the first cell of the slot.
 static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         struct llama_kv_cache & cache,
-        const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
+        const struct llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs = ubatch.n_seqs;
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should be always be contiguous.
 
         // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);
 
         int32_t min = cache.size - 1;
         int32_t max = 0;
 
         // everything should fit if all seq_ids are smaller than the max
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
+            const uint32_t n_seq_id = ubatch.n_seq_id[s];
             for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
 
                 if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                     // too big seq_id
@@ -3864,7 +3864,7 @@ static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
         // find usable cell range
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
             llama_kv_cell & seq_meta = cache.cells[seq_id];
             bool has_cell = false;
             if (seq_meta.tail >= 0) {
@@ -3903,7 +3903,7 @@ static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         // gather and re-order
         for (uint32_t s = 0; s < n_seqs; ++s) {
             int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
             if (dst_id != src_id) {
                 llama_kv_cell & dst_cell = cache.cells[dst_id];
                 llama_kv_cell & src_cell = cache.cells[src_id];
@@ -3924,20 +3924,20 @@ static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
         // update the pos of the used seqs
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
             int32_t cell_id = s + min;
             llama_kv_cell & cell = cache.cells[cell_id];
 
             if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
                 LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
             }
             cell.pos = last_pos;
             cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                 cell.seq_id.insert(seq_id);
                 cache.cells[seq_id].tail = cell_id;
             }
@@ -3991,10 +3991,10 @@ static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
+            cache.cells[cache.head + k].pos = ubatch.pos[k];
 
-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
             }
         }
     }
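Note: the slot-finding hunks above address the micro-batch through a flattened layout: ubatch.pos holds n_seqs sequences of n_seq_tokens tokens each, so token i of sequence entry s sits at index s*n_seq_tokens + i, and ubatch.seq_id[s] lists the n_seq_id[s] sequence ids that entry belongs to. The sketch below only illustrates that indexing convention; the struct is a trimmed stand-in for the fields the diff touches, not the real llama_ubatch definition.

```cpp
#include <cstdint>
#include <cstdio>

// Trimmed, illustrative stand-in for the llama_ubatch fields used by
// llama_kv_cache_find_slot (not the real struct definition).
struct ubatch_view {
    uint32_t   n_tokens;     // total tokens; n_seqs * n_seq_tokens when equal_seqs
    uint32_t   n_seqs;       // number of sequence entries in this micro-batch
    uint32_t   n_seq_tokens; // new tokens per sequence entry
    int32_t  * n_seq_id;     // [n_seqs] how many seq_ids each entry carries
    int32_t ** seq_id;       // [n_seqs][n_seq_id[s]] sequence ids
    int32_t  * pos;          // [n_tokens] positions, flattened as s*n_seq_tokens + i
};

// Walk the micro-batch the same way the kv-cache code does.
static void walk(const ubatch_view & ub) {
    for (uint32_t s = 0; s < ub.n_seqs; ++s) {
        for (uint32_t i = 0; i < ub.n_seq_tokens; ++i) {
            const uint32_t k = s*ub.n_seq_tokens + i; // same flattening as the diff
            std::printf("seq entry %u, token %u -> pos %d\n", s, i, ub.pos[k]);
        }
        for (int32_t j = 0; j < ub.n_seq_id[s]; ++j) {
            std::printf("seq entry %u carries seq_id %d\n", s, ub.seq_id[s][j]);
        }
    }
}

int main() {
    // Two sequences, three new tokens each (values are made up for illustration).
    int32_t seq0[] = {0};
    int32_t seq1[] = {1};
    int32_t * seq_id[] = {seq0, seq1};
    int32_t n_seq_id[] = {1, 1};
    int32_t pos[] = {5, 6, 7, 9, 10, 11};
    ubatch_view ub = {6, 2, 3, n_seq_id, seq_id, pos};
    walk(ub);
}
```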
@@ -9931,21 +9931,21 @@ static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
         struct llama_context & lctx,
         const llama_hparams & hparams,
-        const llama_ubatch & batch,
+        const llama_ubatch & ubatch,
         struct ggml_tensor * tok_embd,
         const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
 
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+    if (ubatch.token) {
+        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
         cb(lctx.inp_tokens, "inp_tokens", -1);
         ggml_set_input(lctx.inp_tokens);
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
     } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
@@ -10518,7 +10518,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
 static struct ggml_tensor * llm_build_mamba(
         struct ggml_context * ctx,
         struct llama_context & lctx,
-        const llama_ubatch & batch,
+        const llama_ubatch & ubatch,
         struct ggml_cgraph * graph,
         struct ggml_tensor * cur,
         struct ggml_tensor * state_copy,
@@ -10534,17 +10534,17 @@ static struct ggml_tensor * llm_build_mamba(
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
     const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs = batch.n_seqs;
+    const int64_t n_seqs = ubatch.n_seqs;
     // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
     const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
     // Use the same RMS norm as the final layer norm
     const float norm_rms_eps = hparams.f_norm_rms_eps;
 
-    const int64_t n_seq_tokens = batch.n_seq_tokens;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
     GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(batch.equal_seqs);
-    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     struct ggml_tensor * conv_states_all = kv.k_l[il];
     struct ggml_tensor * ssm_states_all = kv.v_l[il];
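In the Mamba path, the asserts above pin down the micro-batch shape: equal_seqs must hold so the per-layer states can be viewed per sequence, which is why n_tokens has to factor exactly into n_seq_tokens * n_seqs. A minimal illustrative check of that invariant (a hypothetical helper, not code from this PR):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical shape check mirroring the GGML_ASSERTs in llm_build_mamba:
// an equal-seqs micro-batch must factor exactly into n_seqs * n_seq_tokens.
static void check_mamba_ubatch_shape(uint32_t n_tokens, uint32_t n_seqs,
                                     uint32_t n_seq_tokens, bool equal_seqs) {
    assert(equal_seqs);                        // same number of new tokens per sequence
    assert(n_seqs != 0);                       // at least one sequence
    assert(n_tokens == n_seq_tokens * n_seqs); // e.g. 6 tokens == 3 per seq * 2 seqs
}

int main() {
    check_mamba_ubatch_shape(/*n_tokens=*/6, /*n_seqs=*/2, /*n_seq_tokens=*/3, /*equal_seqs=*/true);
}
```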
@@ -21828,10 +21828,10 @@ struct llama_data_read {
 
             llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
 
-            llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
-            batch.n_tokens = cell_count;
-            batch.n_seq_tokens = cell_count;
-            batch.n_seqs = 1;
+            llama_ubatch ubatch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+            ubatch.n_tokens = cell_count;
+            ubatch.n_seq_tokens = cell_count;
+            ubatch.n_seqs = 1;
 
             for (uint32_t i = 0; i < cell_count; ++i) {
                 llama_pos pos;
@@ -21845,20 +21845,20 @@ struct llama_data_read {
                     return false;
                 }
 
-                batch.pos[i] = pos;
+                ubatch.pos[i] = pos;
             }
-            batch.n_seq_id[0] = 1;
-            batch.seq_id[0] = &dest_seq_id;
-            if (!llama_kv_cache_find_slot(kv_self, batch)) {
+            ubatch.n_seq_id[0] = 1;
+            ubatch.seq_id[0] = &dest_seq_id;
+            if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
                 LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
                 return false;
             }
 
             // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
             // Assume that this is one contiguous block of cells
             GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
-            GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
-            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+            GGML_ASSERT(kv_self.cells[kv_self.head].pos == ubatch.pos[0]);
+            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
             GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
             GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
         } else {