@@ -541,41 +541,34 @@ bool validate_flash_attention_args(
 bool validate_cache_params(
     const Tensor& k_cache,
     const Tensor& v_cache,
-    int64_t layer_id,
     int64_t start_pos,
     int64_t seq_length) {
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      k_cache.dim() == 5, "kcache must be a 5D tensor");
+      k_cache.dim() == 4, "kcache must be a 4D tensor");
 
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      v_cache.dim() == 5, "v_cache must be a 5D tensor");
+      v_cache.dim() == 4, "v_cache must be a 4D tensor");
 
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      layer_id < k_cache.size(0), "layer_id must be less than kcache dim 0");
-
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      layer_id < v_cache.size(0), "layer_id must be less than vcache dim 0");
-
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      start_pos < k_cache.size(2),
+      start_pos < k_cache.size(1),
       "start_pos must be less than key cache at dim 1");
 
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      start_pos < v_cache.size(2),
+      start_pos < v_cache.size(1),
       "start_pos must be less than value cache at dim 1");
 
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      (start_pos + seq_length) < k_cache.size(2),
+      (start_pos + seq_length) < k_cache.size(1),
       "start_post + seq_length must be less than max seq length supported by key cache."
       "start pos: %" PRId64 ", seq_length: %" PRId64
       "."
       "key cache size: %zd",
       start_pos,
       seq_length,
-      k_cache.size(2));
+      k_cache.size(1));
 
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      (start_pos + seq_length) < v_cache.size(2),
+      (start_pos + seq_length) < v_cache.size(1),
       "start_post + seq_length must be less than max seq length supported by key cache."
       "start pos: %" PRId64 ", seq_length: %" PRId64
       "."
@@ -600,14 +593,13 @@ bool validate_cache_params(
 void update_cache(
     const Tensor& projected_value,
     const Tensor& cache,
-    int64_t layer_id,
     int64_t start_pos,
     int64_t seq_length) {
   ET_CHECK_MSG(seq_length == 1, "seq_length must be 1");
   ET_CHECK_MSG(
       projected_value.size(0) == 1,
       "projected_value must have batch size of 1");
-  ET_CHECK_MSG(cache.size(1) == 1, "cache must have batch size of 1");
+  ET_CHECK_MSG(cache.size(0) == 1, "cache must have batch size of 1");
   ET_CHECK_MSG(
       is_default_dim_order(
           projected_value.dim_order().data(), projected_value.dim()),
@@ -619,10 +611,8 @@ void update_cache(
   ET_CHECK_MSG(cache_data, "cache data is null");
 
   auto strides = cache.strides();
-  exec_aten::StridesType layer_stride = strides[0];
-  exec_aten::StridesType seq_dim_stride = strides[2];
-  exec_aten::SizesType pos_offset =
-      layer_id * layer_stride + start_pos * seq_dim_stride;
+  exec_aten::StridesType seq_dim_stride = strides[1];
+  exec_aten::SizesType pos_offset = start_pos * seq_dim_stride;
   exec_aten::SizesType pos_offset_bytes =
       pos_offset * projected_value.element_size();
   exec_aten::SizesType num_bytes =
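
The patched offset math above assumes each layer owns a contiguous cache of shape [batch = 1, max_seq_len, n_heads, head_dim], so only start_pos and the sequence-dimension stride matter. A minimal standalone sketch of that arithmetic, using plain floats and hypothetical sizes rather than the ExecuTorch Tensor API:

#include <cstdint>
#include <cstring>

// Copy one token's projected K (or V) into a contiguous per-layer cache laid
// out as [1, max_seq_len, n_heads, head_dim]. For that layout,
// strides[1] == n_heads * head_dim, so start_pos * strides[1] is the element
// offset the patched update_cache() computes; layer_id no longer contributes.
void write_token_to_cache(
    float* cache_data,      // base pointer of this layer's cache
    const float* projected, // one token: n_heads * head_dim elements
    int64_t start_pos,
    int64_t n_heads,
    int64_t head_dim) {
  const int64_t seq_dim_stride = n_heads * head_dim;
  const int64_t pos_offset = start_pos * seq_dim_stride;
  std::memcpy(
      cache_data + pos_offset, projected, seq_dim_stride * sizeof(float));
}
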
@@ -713,19 +703,16 @@ Tensor& flash_attention_kernel_out(
   @param[in] key_cache Cache of previous v_projected.
   Format [n_layers, batch size, max_seq_len, num heads, head dim]
   ....
-  @param[in] layer_id which layer this call belongs to.
-             Used to updated appropriate entry of kv cache
-  @param[in] start_pos sequence position
-  @param[in] seq_len Seq length. e.g. seq_len dim of q_projected.
+  @param[in] start_pos: sequence position
+  @param[in] seq_len: Seq length. e.g. seq_len dim of q_projected.
  */
 Tensor& sdpa_with_kv_cache_out(
     RuntimeContext& ctx,
     const Tensor& q_projected,
     const Tensor& k_projected,
     const Tensor& v_projected,
-    const Tensor& key_cache,
-    const Tensor& value_cache,
-    const int64_t layer_id, // THis should be gone with buffer based impl
+    Tensor& key_cache,
+    Tensor& value_cache,
     const int64_t start_pos,
     const int64_t seq_len,
     const optional<Tensor>& attn_mask,
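
With layer_id dropped from the signature, the caller is expected to keep one 4D K/V cache pair per layer and pass the right pair in as mutable tensors. A rough caller-side sketch of that ownership pattern; LayerKVCache, make_caches, and the flattened storage are illustrative assumptions, not names from this patch:

#include <cstddef>
#include <cstdint>
#include <vector>

// One pair of caches per transformer layer, each logically
// [1, max_seq_len, n_heads, head_dim] and flattened here for brevity.
struct LayerKVCache {
  std::vector<float> k;
  std::vector<float> v;
};

std::vector<LayerKVCache> make_caches(
    int64_t n_layers,
    int64_t max_seq_len,
    int64_t n_heads,
    int64_t head_dim) {
  const size_t per_layer =
      static_cast<size_t>(max_seq_len * n_heads * head_dim);
  std::vector<LayerKVCache> caches(static_cast<size_t>(n_layers));
  for (auto& cache : caches) {
    cache.k.assign(per_layer, 0.0f);
    cache.v.assign(per_layer, 0.0f);
  }
  return caches;
}

// At decode time the caller indexes caches[layer] and hands caches[layer].k
// and caches[layer].v (wrapped as Tensors) to sdpa_with_kv_cache_out, instead
// of passing a shared 5D buffer plus layer_id.
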
@@ -737,34 +724,31 @@ Tensor& sdpa_with_kv_cache_out(
   (void)ctx;
   ET_KERNEL_CHECK(
       ctx,
-      validate_cache_params(
-          key_cache, value_cache, layer_id, start_pos, seq_len),
+      validate_cache_params(key_cache, value_cache, start_pos, seq_len),
       InvalidArgument,
       output);
 
   ET_CHECK_MSG(q_projected.dim() == 4, "query must be a 4D tensor");
 
-  update_cache(k_projected, key_cache, layer_id, start_pos, seq_len);
-  update_cache(v_projected, value_cache, layer_id, start_pos, seq_len);
+  update_cache(k_projected, key_cache, start_pos, seq_len);
+  update_cache(v_projected, value_cache, start_pos, seq_len);
 
   auto q_seq_len = q_projected.size(1);
 
   std::array<exec_aten::DimOrderType, util::kKVDim> sliced_key_dim_order{
       0, 1, 2, 3};
   std::array<exec_aten::SizesType, util::kKVDim> sliced_key_sizes;
-  sliced_key_sizes[0] = key_cache.size(1);
+  sliced_key_sizes[0] = key_cache.size(0);
   sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
-  sliced_key_sizes[2] = key_cache.size(3);
-  sliced_key_sizes[3] = key_cache.size(4);
+  sliced_key_sizes[2] = key_cache.size(2);
+  sliced_key_sizes[3] = key_cache.size(3);
   std::array<exec_aten::StridesType, util::kKVDim> sliced_key_strides;
   dim_order_to_stride_nocheck(
       sliced_key_sizes.data(),
       sliced_key_dim_order.data(),
       util::kKVDim,
       sliced_key_strides.data());
-  void* key_cache_data = reinterpret_cast<void*>(
-      reinterpret_cast<ptrdiff_t>(key_cache.mutable_data_ptr()) +
-      layer_id * key_cache.strides()[0] * key_cache.element_size());
+  void* key_cache_data = key_cache.mutable_data_ptr();
   TensorImpl k_impl = TensorImpl(
       key_cache.scalar_type(),
       util::kKVDim,
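
The sliced key view keeps the cache's own head and head-dim sizes but shrinks the sequence dimension to start_pos + seq_len, so attention only reads the filled prefix. A small sketch of the row-major strides the default {0, 1, 2, 3} dim order implies, which is what dim_order_to_stride_nocheck is expected to produce here (illustrative only):

#include <array>
#include <cstdint>

// Row-major strides for a 4D shape: stride[3] = 1 and
// stride[d] = stride[d + 1] * sizes[d + 1].
std::array<int64_t, 4> contiguous_strides(const std::array<int64_t, 4>& sizes) {
  std::array<int64_t, 4> strides{};
  strides[3] = 1;
  for (int d = 2; d >= 0; --d) {
    strides[d] = strides[d + 1] * sizes[d + 1];
  }
  return strides;
}

// Example: a [1, 128, 8, 64] cache viewed as [1, start_pos + seq_len, 8, 64].
// Shrinking dim 1 leaves the strides of dims 1..3 unchanged (they depend only
// on the trailing sizes) and batch is 1, so the view can reuse the cache's
// storage from its base pointer, which is why key_cache_data is now just
// mutable_data_ptr() with no layer offset.
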
@@ -778,19 +762,17 @@ Tensor& sdpa_with_kv_cache_out(
   std::array<exec_aten::DimOrderType, util::kKVDim> sliced_value_dim_order{
       0, 1, 2, 3};
   std::array<exec_aten::SizesType, util::kKVDim> sliced_value_sizes;
-  sliced_value_sizes[0] = value_cache.size(1);
+  sliced_value_sizes[0] = value_cache.size(0);
   sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
-  sliced_value_sizes[2] = value_cache.size(3);
-  sliced_value_sizes[3] = value_cache.size(4);
+  sliced_value_sizes[2] = value_cache.size(2);
+  sliced_value_sizes[3] = value_cache.size(3);
   std::array<exec_aten::StridesType, util::kKVDim> sliced_value_strides;
   dim_order_to_stride_nocheck(
       sliced_value_sizes.data(),
       sliced_value_dim_order.data(),
       util::kKVDim,
       sliced_value_strides.data());
-  void* value_cache_data = reinterpret_cast<void*>(
-      reinterpret_cast<ptrdiff_t>(value_cache.mutable_data_ptr()) +
-      layer_id * value_cache.strides()[0] * value_cache.element_size());
+  void* value_cache_data = value_cache.mutable_data_ptr();
   TensorImpl value_impl = TensorImpl(
       value_cache.scalar_type(),
       util::kKVDim,