@@ -3591,27 +3591,27 @@ static bool llama_kv_cache_init(
 // to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
+       const struct llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs = ubatch.n_seqs;
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should be always be contiguous.
 
         // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);
 
         int32_t min = cache.size - 1;
         int32_t max = 0;
 
         // everything should fit if all seq_ids are smaller than the max
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
+            const uint32_t n_seq_id = ubatch.n_seq_id[s];
             for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
 
                 if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                     // too big seq_id
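Note: this hunk only renames the parameter, but the loop it touches encodes an invariant worth spelling out: for recurrent caches every seq_id doubles as a cell index, so a seq_id outside [0, cache.size) can never be given a slot. A minimal stand-alone sketch of that check, assuming the internal llama_kv_cache and llama_ubatch definitions from this file (the helper name is invented for illustration):

    // Illustrative only: mirrors the bounds check in the loop above.
    static bool ubatch_seq_ids_fit(const struct llama_kv_cache & cache, const struct llama_ubatch & ubatch) {
        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
            const uint32_t n_seq_id = ubatch.n_seq_id[s];
            for (uint32_t j = 0; j < n_seq_id; ++j) {
                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                    return false; // this seq_id cannot map to any cache cell
                }
            }
        }
        return true;
    }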
@@ -3670,7 +3670,7 @@ static bool llama_kv_cache_find_slot(
 
         // find usable cell range
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
             llama_kv_cell & seq_meta = cache.cells[seq_id];
             bool has_cell = false;
             if (seq_meta.tail >= 0) {
@@ -3709,7 +3709,7 @@ static bool llama_kv_cache_find_slot(
         // gather and re-order
         for (uint32_t s = 0; s < n_seqs; ++s) {
             int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
             if (dst_id != src_id) {
                 llama_kv_cell & dst_cell = cache.cells[dst_id];
                 llama_kv_cell & src_cell = cache.cells[src_id];
@@ -3730,20 +3730,20 @@ static bool llama_kv_cache_find_slot(
 
         // update the pos of the used seqs
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
             int32_t cell_id = s + min;
             llama_kv_cell & cell = cache.cells[cell_id];
 
             if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
                 LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
             }
             cell.pos = last_pos;
             cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                 cell.seq_id.insert(seq_id);
                 cache.cells[seq_id].tail = cell_id;
             }
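Note: the warning in this hunk fires when the new positions do not continue where the cell left off. Stated on its own, the rule is that the last position of sequence s in the ubatch must be exactly n_seq_tokens past the position stored in its cell, unless the cell is fresh (pos < 0). A sketch under the same internal-type assumptions as above, with an invented helper name:

    // Illustrative only: the continuity rule behind the LLAMA_LOG_WARN above.
    static bool ubatch_positions_continue(const struct llama_kv_cell & cell,
                                          const struct llama_ubatch & ubatch, uint32_t s) {
        const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
        const llama_pos last_pos = ubatch.pos[n_seq_tokens*s + n_seq_tokens - 1];
        // a fresh cell accepts any starting position
        return cell.pos < 0 || last_pos == cell.pos + (llama_pos) n_seq_tokens;
    }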
@@ -3795,10 +3795,10 @@ static bool llama_kv_cache_find_slot(
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
+            cache.cells[cache.head + k].pos = ubatch.pos[k];
 
-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
             }
         }
     }
@@ -9178,21 +9178,21 @@ static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
        struct llama_context & lctx,
         const llama_hparams & hparams,
-         const llama_ubatch & batch,
+         const llama_ubatch & ubatch,
         struct ggml_tensor * tok_embd,
          const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
 
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+    if (ubatch.token) {
+        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
         cb(lctx.inp_tokens, "inp_tokens", -1);
         ggml_set_input(lctx.inp_tokens);
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
     } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
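Note: the builder in this hunk creates one of two graph inputs: an I32 tensor of token ids that feeds ggml_get_rows over tok_embd, or an F32 [n_embd, n_tokens] tensor of pre-computed embeddings. A hedged sketch of how such inputs are typically filled after graph allocation; llama.cpp does this in its own input-setting step, and the helper below (including the use of lctx.model.hparams) is illustrative rather than the actual code:

    // Sketch only: copy the ubatch payload into whichever input tensor was created.
    static void set_inp_embd_example(struct llama_context & lctx, const struct llama_ubatch & ubatch) {
        if (ubatch.token) {
            // token ids feed the I32 tensor consumed by ggml_get_rows
            ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0,
                    ubatch.n_tokens*ggml_element_size(lctx.inp_tokens));
        } else {
            // pre-computed embeddings feed the F32 [n_embd, n_tokens] tensor
            ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0,
                    ubatch.n_tokens*lctx.model.hparams.n_embd*ggml_element_size(lctx.inp_embd));
        }
    }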
@@ -9766,7 +9766,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
 static struct ggml_tensor * llm_build_mamba(
         struct ggml_context * ctx,
        struct llama_context & lctx,
-         const llama_ubatch & batch,
+         const llama_ubatch & ubatch,
          struct ggml_cgraph * graph,
          struct ggml_tensor * cur,
          struct ggml_tensor * state_copy,
@@ -9782,17 +9782,17 @@ static struct ggml_tensor * llm_build_mamba(
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
     const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs = batch.n_seqs;
+    const int64_t n_seqs = ubatch.n_seqs;
     // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
     const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
     // Use the same RMS norm as the final layer norm
     const float norm_rms_eps = hparams.f_norm_rms_eps;
 
-    const int64_t n_seq_tokens = batch.n_seq_tokens;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
     GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(batch.equal_seqs);
-    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     struct ggml_tensor * conv_states_all = kv.k_l[il];
     struct ggml_tensor * ssm_states_all = kv.v_l[il];
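Note: the asserts touched here pin down the shape contract for the Mamba path: the ubatch must be an equal split across sequences, with n_tokens == n_seq_tokens * n_seqs, so the activations can be treated as one block of n_seq_tokens tokens per sequence, each with its own SSM state. A small made-up example that satisfies the contract (values are illustrative only):

    // Illustrative only: 3 sequences, 4 new tokens each.
    llama_ubatch ubatch = {};
    ubatch.equal_seqs   = true;
    ubatch.n_seqs       = 3;
    ubatch.n_seq_tokens = 4;
    ubatch.n_tokens     = ubatch.n_seqs * ubatch.n_seq_tokens; // must be 12 here, or the assert fires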
@@ -20440,10 +20440,10 @@ struct llama_data_read {
 
             llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
 
-            llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
-            batch.n_tokens = cell_count;
-            batch.n_seq_tokens = cell_count;
-            batch.n_seqs = 1;
+            llama_ubatch ubatch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+            ubatch.n_tokens = cell_count;
+            ubatch.n_seq_tokens = cell_count;
+            ubatch.n_seqs = 1;
 
             for (uint32_t i = 0; i < cell_count; ++i) {
                 llama_pos pos;
@@ -20457,20 +20457,20 @@ struct llama_data_read {
                     return false;
                 }
 
-                batch.pos[i] = pos;
+                ubatch.pos[i] = pos;
             }
-            batch.n_seq_id[0] = 1;
-            batch.seq_id[0] = &dest_seq_id;
-            if (!llama_kv_cache_find_slot(kv_self, batch)) {
+            ubatch.n_seq_id[0] = 1;
+            ubatch.seq_id[0] = &dest_seq_id;
+            if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
                 LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
                 return false;
             }
 
             // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
             // Assume that this is one contiguous block of cells
             GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
-            GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
-            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+            GGML_ASSERT(kv_self.cells[kv_self.head].pos == ubatch.pos[0]);
+            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
             GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
             GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
         } else {
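Note: the last two hunks are the state-restore path. A one-sequence ubatch is fabricated from the saved cell metadata so that llama_kv_cache_find_slot can reserve one contiguous block, and the asserts then verify that the block really starts at kv_self.head. From the public API side this code runs when a saved sequence is loaded back into the KV cache, for example through llama_state_seq_load_file; a hedged usage sketch, where the file name and token capacity are purely illustrative:

    #include <vector>
    #include "llama.h"

    // Sketch only: restore a previously saved sequence into dest_seq_id.
    static bool restore_sequence_example(llama_context * ctx, llama_seq_id dest_seq_id) {
        std::vector<llama_token> tokens(4096); // assumed token capacity
        size_t n_tokens = 0;
        const size_t n_read = llama_state_seq_load_file(
                ctx, "session.bin", dest_seq_id, tokens.data(), tokens.size(), &n_tokens);
        return n_read != 0; // 0 means the sequence could not be restored
    }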