  * LICENSE file in the root directory of this source tree.
  */

-#include <algorithm>
-#include <fstream>
-
 #include <executorch/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include <algorithm>

 using executorch::aten::Tensor;
 using executorch::aten::TensorImpl;
@@ -55,7 +53,8 @@ std::vector<Tensor> Memory::get_output_tensors(
 HybridMemory::HybridMemory(
     std::vector<std::shared_ptr<Module>>& modules,
-    int32_t max_seq_len,
+    int32_t prefill_cache_len,
+    int32_t kv_cache_len,
     int32_t vocab_size,
     int32_t num_layers,
     int32_t head_dim,
@@ -65,7 +64,8 @@ HybridMemory::HybridMemory(
     const std::string& kv_forward_name)
     : Memory(modules),
       shard_layers_({num_layers}),
-      max_seq_len_(max_seq_len),
+      prefill_cache_len_(prefill_cache_len),
+      kv_cache_len_(kv_cache_len),
       vocab_size_(vocab_size),
       num_layers_(num_layers),
       head_dim_(head_dim),
@@ -106,17 +106,17 @@ HybridMemory::HybridMemory(
       new IO, [](void* ptr) { delete static_cast<IO*>(ptr); });
 }

-void HybridMemory::init_io(
-    const std::vector<executorch::runtime::Result<
-        executorch::runtime::MethodMeta>>& methods_meta,
-    EvalMode eval_mode) {
+void HybridMemory::init_io() {
   IO* ptr = static_cast<IO*>(data_ptr_.get());
   std::memset(ptr, 0, sizeof(IO));

-  int32_t cache_len = max_seq_len_ - 1;
-  int32_t k_in_size = (head_dim_ + 1) * (max_seq_len_ - 1);
-  int32_t k_cache_out_size = num_heads_ * head_dim_ * cache_len;
-  int32_t v_cache_size = (num_heads_ + 1) * (max_seq_len_ - 1) * head_dim_;
+  int32_t max_cache_len = std::max(kv_cache_len_, prefill_cache_len_);
+  int32_t k_in_size = (head_dim_ + 1) * max_cache_len;
+  int32_t v_cache_size = (num_heads_ + 1) * max_cache_len * head_dim_;
+  int32_t k_cache_out_size = num_heads_ * head_dim_;
+  if (eval_mode_ == EvalMode::kHybrid || eval_mode_ == EvalMode::kPrefill) {
+    k_cache_out_size *= prefill_cache_len_;
+  }

   // Init kv vector shape, general enough to be shared across all 3 modes.
   ptr->k_cache_out.reserve(num_layers_);
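As a worked example of the sizing above, with illustrative dimensions that are not taken from any particular checkpoint (head_dim_ = 128, num_heads_ = 8, kv_cache_len_ = 512, prefill_cache_len_ = 32, hybrid mode):

    // Assumed dimensions, for illustration only:
    // max_cache_len    = std::max(512, 32)     = 512
    // k_in_size        = (128 + 1) * 512       = 66048   -- the extra row leaves room for the
    //                                                       shift-by-one update in update_kv_io
    // v_cache_size     = (8 + 1) * 512 * 128   = 589824  -- one spare head-sized region where
    //                                                       new v outputs are appended
    // k_cache_out_size = 8 * 128 * 32 (hybrid) = 32768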
@@ -127,14 +127,14 @@ void HybridMemory::init_io(
   }

   auto init_prefill = [&]() {
-    ptr->prefill_input_toks.resize(cache_len);
-    ptr->prefill_atten_mask.resize(cache_len * cache_len);
-    ptr->prefill_logits.resize(cache_len * vocab_size_);
+    ptr->prefill_input_toks.resize(prefill_cache_len_);
+    ptr->prefill_atten_mask.resize(prefill_cache_len_ * prefill_cache_len_);
+    ptr->prefill_logits.resize(prefill_cache_len_ * vocab_size_);
   };

   auto init_kv = [&]() {
     ptr->kv_logits.resize(vocab_size_);
-    ptr->kv_attention_mask.resize(max_seq_len_, -255);
+    ptr->kv_attention_mask.resize((kv_cache_len_ + 1), -255);
     ptr->k_cache.reserve(num_layers_);
     for (int layer = 0; layer < num_layers_; layer++) {
       ptr->k_cache.emplace_back();
@@ -145,7 +145,7 @@ void HybridMemory::init_io(
     }
   };

-  switch (eval_mode) {
+  switch (eval_mode_) {
     case EvalMode::kPrefill:
       init_prefill();
       break;
@@ -205,9 +205,7 @@ void HybridMemory::prepare_kv_io(

   // [I] kv_cache
   int index = 3; // bypass input_tokens, input_pos, atten_mask
-  for (int offset = 0,
-           shard_index = 0,
-           v_stride = (max_seq_len_ - 1) * head_dim_;
+  for (int offset = 0, shard_index = 0, v_stride = kv_cache_len_ * head_dim_;
        shard_index < modules_.size();
        offset += shard_layers_[shard_index], shard_index++) {
     for (int cache_group = 0; cache_group < 2; ++cache_group) {
@@ -256,9 +254,7 @@ void HybridMemory::prepare_kv_io(
   // For k, we store it in k_cache_out and update to k_cache later.
   // For v, we append the output to the end of v_cache,
   // which serves as both input and output.
-  for (int offset = 0,
-           shard_index = 0,
-           v_stride = (max_seq_len_ - 1) * head_dim_;
+  for (int offset = 0, shard_index = 0, v_stride = kv_cache_len_ * head_dim_;
        shard_index < modules_.size();
        offset += shard_layers_[shard_index], shard_index++) {
     for (int cache_group = 0; cache_group < 2; ++cache_group) {
@@ -305,8 +301,6 @@ void HybridMemory::prepare_prefill_io(
   IO* ptr = static_cast<IO*>(data_ptr_.get());

-  // cache_len should be max_seq_len - 1
-  int32_t cache_len = methods_meta[0]->input_tensor_meta(0)->sizes()[1];
   // [I]: pre_input_tokens
   Result<TensorInfo> prefill_input_toks = methods_meta[0]->input_tensor_meta(0);
   prefill_input_toks_ = std::make_unique<TensorImpl>(
@@ -318,12 +312,12 @@ void HybridMemory::prepare_prefill_io(
           prefill_input_toks->dim_order().data()));
   input_tensors_[prefill_forward_name_][0].push_back(prefill_input_toks_.get());
   // [I]: prefill_attn_mask
-  for (int i = 0; i < cache_len; ++i) {
-    for (int j = 0; j < cache_len; ++j) {
+  for (int i = 0; i < prefill_cache_len_; ++i) {
+    for (int j = 0; j < prefill_cache_len_; ++j) {
       if (i < j) {
-        ptr->prefill_atten_mask[i * cache_len + j] = -255;
+        ptr->prefill_atten_mask[i * prefill_cache_len_ + j] = -255;
       } else {
-        ptr->prefill_atten_mask[i * cache_len + j] = 0;
+        ptr->prefill_atten_mask[i * prefill_cache_len_ + j] = 0;
       }
     }
   }
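To make the mask concrete: with a hypothetical prefill_cache_len_ of 4, the loop fills prefill_atten_mask row-major with a causal pattern, 0 where position i may attend to position j (j <= i) and -255 elsewhere:

    // i\j    0     1     2     3
    // 0      0  -255  -255  -255
    // 1      0     0  -255  -255
    // 2      0     0     0  -255
    // 3      0     0     0     0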
@@ -347,10 +341,22 @@ void HybridMemory::prepare_prefill_io(
       const_cast<TensorImpl::DimOrderType*>(logits->dim_order().data()));
   output_tensors_[prefill_forward_name_][modules_.size() - 1].push_back(
       prefill_logits_.get());
+
   // [O] kv_cache
   int index = 1;
-  for (int offset = 0, shard_index = 0, cache_stride = cache_len * head_dim_;
-       shard_index < modules_.size();
+  // prefill_k_stride should be equal to prefill_v_stride in prefill mode.
+  // In hybrid mode, the v stride uses the kv-mode cache length, since we want
+  // to write prefill's results directly onto the kv mode's input.
+  int32_t prefill_k_stride = prefill_cache_len_ * head_dim_;
+  int32_t prefill_v_stride =
+      std::max(prefill_cache_len_, kv_cache_len_) * head_dim_;
+
+  if (eval_mode_ == EvalMode::kPrefill) {
+    ET_CHECK_MSG(
+        prefill_k_stride == prefill_v_stride,
+        "prefill_k_stride should be equal to prefill_v_stride");
+  }
+  for (int offset = 0, shard_index = 0; shard_index < modules_.size();
        offset += shard_layers_[shard_index], shard_index++) {
     for (int cache_group = 0; cache_group < 2; ++cache_group) {
       for (int layer = 0; layer < shard_layers_[shard_index]; ++layer) {
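Reading the strides with the same illustrative numbers as earlier (prefill_cache_len_ = 32, kv_cache_len_ = 512, head_dim_ = 128, hybrid mode):

    // prefill_k_stride = 32 * 128           = 4096  -- k outputs stay packed at prefill length
    //                                                  and are copied over in update_prefill_to_kv_io
    // prefill_v_stride = max(32, 512) * 128 = 65536 -- v outputs are spaced on the kv-mode layout,
    //                                                  so prefill writes v directly where the kv
    //                                                  graph will read it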
@@ -363,10 +369,10 @@ void HybridMemory::prepare_prefill_io(
         void* cache_ptr = (cache_group == 0)
             ? static_cast<void*>(
                   ptr->k_cache_out[layer + offset].data() +
-                  head * cache_stride)
+                  head * prefill_k_stride)
             : static_cast<void*>(
                   ptr->v_cache[layer + offset].data() +
-                  (head + 1) * cache_stride);
+                  (head + 1) * prefill_v_stride);
         cache.emplace_back(std::make_unique<TensorImpl>(
             kv_cache->scalar_type(),
             kv_cache->sizes().size(),
@@ -386,15 +392,17 @@ void HybridMemory::update_prefill_to_kv_io(
     int64_t cur_token,
     int64_t pos,
     std::vector<std::vector<Tensor>>& output_tensors) {
-  int cache_len = (max_seq_len_ - 1);
+  ET_CHECK_MSG(kv_cache_len_ != 0, "kv_cache_len_ should not be 0");
+  ET_CHECK_MSG(
+      prefill_cache_len_ != 0, "prefill_cache_len_ should not be 0");
   IO* ptr = static_cast<IO*>(data_ptr_.get());

   ptr->input_tok = static_cast<int32_t>(cur_token);
   ptr->input_pos = static_cast<int32_t>(pos);
   // If prompt len is 30, prefill will handle up to pos = 30.
   // At this point, pos should be 31.
   for (int i = 0; i < pos + 1; i++) {
-    ptr->kv_attention_mask[cache_len - i] = 0;
+    ptr->kv_attention_mask[kv_cache_len_ - i] = 0;
   }

   // update v_cache
@@ -429,9 +437,9 @@ void HybridMemory::update_prefill_to_kv_io(
   for (int i = 0; i < k_cache_in.size(); ++i) {
     uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>();
     const uint8_t* ptr_out = k_cache_out[i]->data<uint8_t>();
-    for (size_t j = 0, offset = cache_len; j < head_dim_;
-         ++j, offset += cache_len) {
-      for (int k = 0, k_stride = j * cache_len; k < pos; k++) {
+    for (size_t j = 0, offset = kv_cache_len_; j < head_dim_;
+         ++j, offset += kv_cache_len_) {
+      for (int k = 0, k_stride = j * prefill_cache_len_; k < pos; k++) {
         ptr_in[offset + k] = ptr_out[k_stride + k];
       }
     }
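Spelled out, the nested loop above copies prefill's packed k output into the kv-mode cache row by row; rows are prefill_cache_len_ wide on the output side and kv_cache_len_ wide on the input side:

    // ptr_out: row j of the prefill k output starts at j * prefill_cache_len_
    // ptr_in:  row j of the kv-mode k cache is written starting at (j + 1) * kv_cache_len_
    // Only the first `pos` entries of each row are copied, i.e. the tokens
    // the prefill graph actually produced.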
@@ -444,13 +452,12 @@ void HybridMemory::update_kv_io(
     int64_t pos,
     std::vector<std::vector<Tensor>>& output_tensors) {
   IO* ptr = static_cast<IO*>(data_ptr_.get());
-  int seq_len = (max_seq_len_ - 1);
   // update input_tok
   ptr->input_tok = static_cast<int32_t>(cur_token);
   // update position_ids
   ptr->input_pos = static_cast<int32_t>(pos);
   // update causal mask for next token
-  ptr->kv_attention_mask[seq_len - pos] = 0;
+  ptr->kv_attention_mask[kv_cache_len_ - pos] = 0;

   // update v_cache
   auto& v_cache_in = v_cache_in_[kv_forward_name_];
@@ -480,8 +487,8 @@ void HybridMemory::update_kv_io(
   for (int i = 0; i < k_cache_in.size(); ++i) {
     uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>();
     const uint8_t* ptr_out = k_cache_out[i]->data<uint8_t>();
-    for (size_t j = 0, offset = seq_len; j < head_dim_;
-         ++j, offset += seq_len) {
+    for (size_t j = 0, offset = kv_cache_len_; j < head_dim_;
+         ++j, offset += kv_cache_len_) {
       ptr_in[offset] = ptr_out[j];
     }
     k_cache_in[i]->set_data(ptr_in + 1);
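That final set_data(ptr_in + 1) is the shift-pointer trick this file relies on: rather than memmoving the whole cache each step, the new key element is written one full row ahead, and the tensor's base pointer then slides forward by one, so the fresh value lands at the end of every row's window. A minimal self-contained sketch of the idea, with toy sizes and plain arrays (not the runner's actual API):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Toy dimensions: 2 rows (head_dim), window of 4 (kv_cache_len),
  // plus one spare row so the base pointer can slide forward.
  const int head_dim = 2, cache_len = 4;
  std::vector<uint8_t> buf((head_dim + 1) * cache_len, 0);
  uint8_t* base = buf.data(); // sliding view into the buffer

  for (int step = 1; step <= 3; ++step) {
    // Write the new element of row j at (j + 1) * cache_len; once the
    // base advances by one, it becomes the last element of row j's window.
    for (int j = 0; j < head_dim; ++j) {
      base[(j + 1) * cache_len] = static_cast<uint8_t>(step * 10 + j);
    }
    ++base; // the shift-by-one that set_data(ptr_in + 1) performs

    for (int j = 0; j < head_dim; ++j) {
      printf("step %d row %d:", step, j);
      for (int k = 0; k < cache_len; ++k) {
        printf(" %3d", base[j * cache_len + k]);
      }
      printf("\n");
    }
  }
  return 0;
}

Running it shows each row's window sliding left by one slot per step, with the newest value appearing at the right edge, which is exactly the behavior the update_kv_io loop above achieves without copying the history.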