@@ -119,27 +119,27 @@ bool llama_kv_cache_init(

 struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs   = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
+       const struct llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs   = ubatch.n_seqs;
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;

     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should be always be contiguous.

         // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);

         int32_t min = cache.size - 1;
         int32_t max = 0;

         // everything should fit if all seq_ids are smaller than the max
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
+            const uint32_t n_seq_id = ubatch.n_seq_id[s];
             for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];

                 if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                     // too big seq_id
@@ -198,7 +198,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(

         // find usable cell range
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
             llama_kv_cell & seq_meta = cache.cells[seq_id];
             bool has_cell = false;
             if (seq_meta.tail >= 0) {
@@ -237,7 +237,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         // gather and re-order
         for (uint32_t s = 0; s < n_seqs; ++s) {
             int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
             if (dst_id != src_id) {
                 llama_kv_cell & dst_cell = cache.cells[dst_id];
                 llama_kv_cell & src_cell = cache.cells[src_id];
@@ -258,20 +258,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(

         // update the pos of the used seqs
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
             int32_t cell_id = s + min;
             llama_kv_cell & cell = cache.cells[cell_id];

             if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
                 LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
             }
             cell.pos = last_pos;
             cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                 cell.seq_id.insert(seq_id);
                 cache.cells[seq_id].tail = cell_id;
             }
@@ -325,10 +325,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
+            cache.cells[cache.head + k].pos = ubatch.pos[k];

-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
             }
         }
     }
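To make the slot-filling loop in the last hunk concrete, here is a minimal, self-contained C++ sketch of how an ubatch's positions and sequence ids get written into contiguous cache cells starting at `head`. The `cell_t`, `cache_t`, and `ubatch_t` types and the `store_ubatch`/`main` driver are simplified stand-ins invented for illustration, not llama.cpp's real `llama_kv_cell`, `llama_kv_cache`, or `llama_ubatch`; only the flat index `k = s*n_seq_tokens + i` and the writes to `cells[head + k]` follow the diff.

// Self-contained sketch of the final hunk's store loop (simplified, hypothetical types).
#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

struct cell_t {
    int32_t pos = -1;            // token position stored in this cell
    std::set<int32_t> seq_id;    // sequences that share this cell
};

struct cache_t {
    uint32_t head = 0;           // first cell of the slot found for this ubatch
    std::vector<cell_t> cells;
};

struct ubatch_t {
    uint32_t n_seqs;
    uint32_t n_seq_tokens;                      // same number of new tokens per sequence (equal_seqs)
    std::vector<int32_t> pos;                   // n_seqs * n_seq_tokens token positions
    std::vector<std::vector<int32_t>> seq_id;   // sequence ids per sequence
};

// Mirrors the "store the ubatch into the slot" loop from the diff above.
static void store_ubatch(cache_t & cache, const ubatch_t & ubatch) {
    for (uint32_t s = 0; s < ubatch.n_seqs; s++) {
        for (uint32_t i = 0; i < ubatch.n_seq_tokens; ++i) {
            uint32_t k = s*ubatch.n_seq_tokens + i;            // flat token index into ubatch.pos
            cache.cells[cache.head + k].pos = ubatch.pos[k];
            for (int32_t id : ubatch.seq_id[s]) {
                cache.cells[cache.head + k].seq_id.insert(id);
            }
        }
    }
}

int main() {
    cache_t cache;
    cache.cells.resize(8);
    cache.head = 2;                               // pretend the slot search picked cell 2

    ubatch_t ub;
    ub.n_seqs       = 2;
    ub.n_seq_tokens = 2;
    ub.pos    = {10, 11, 5, 6};                   // two new tokens for each of two sequences
    ub.seq_id = {{0}, {1}};

    store_ubatch(cache, ub);

    for (size_t c = 0; c < cache.cells.size(); ++c) {
        std::printf("cell %zu: pos=%d\n", c, cache.cells[c].pos);
    }
    return 0;
}

Running this fills cells 2 through 5 with positions 10, 11, 5, 6, showing how the sequences of one ubatch are packed back-to-back from `cache.head`.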