@@ -224,7 +224,7 @@ void cpu_flash_attention(
    bool is_causal,
    const optional<Tensor>& attn_mask,
    const optional<double>& scale,
-    bool is_with_kv_cache = false,
+    bool is_seq_at_dim_1 = false,
    const int64_t start_pos = 0) {
  (void)dropout_p;
  // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
@@ -265,7 +265,7 @@ void cpu_flash_attention(
  int64_t kvSize = value.size(2);
  int64_t num_heads_kv = key.size(1);

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
    num_head = query.size(2);
    num_heads_kv = key.size(2);
    qSize = query.size(1);
@@ -311,7 +311,7 @@ void cpu_flash_attention(
  int64_t qStrideH = strides[1];
  int64_t qStrideM = strides[2];

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
    qStrideH = strides[2];
    qStrideM = strides[1];
  }
@@ -321,7 +321,7 @@ void cpu_flash_attention(
  int64_t kStrideH = strides[1];
  int64_t kStrideN = strides[2];

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
    kStrideH = strides[2];
    kStrideN = strides[1];
  }
@@ -331,7 +331,7 @@ void cpu_flash_attention(
  int64_t vStrideH = strides[1];
  int64_t vStrideN = strides[2];

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
    vStrideH = strides[2];
    vStrideN = strides[1];
  }
@@ -341,7 +341,7 @@ void cpu_flash_attention(
  int64_t oStrideH = strides[1];
  int64_t oStrideM = strides[2];

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
    oStrideH = strides[2];
    oStrideM = strides[1];
  }
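The four stride swaps above all encode the same layout rule: with the sequence at dim 1 the tensor is (Batch, Seq, Num_heads, Dim) instead of (Batch, Num_heads, Seq, Dim), so the head stride and the sequence stride trade places. A minimal, self-contained sketch of that rule for contiguous tensors (illustration only; these names are not from the kernel):

#include <array>
#include <cstdint>
#include <cstdio>

// Strides of a contiguous tensor with the given 4D sizes.
std::array<int64_t, 4> contiguous_strides(const std::array<int64_t, 4>& sizes) {
  std::array<int64_t, 4> s{};
  s[3] = 1;
  for (int i = 2; i >= 0; --i) {
    s[i] = s[i + 1] * sizes[i + 1];
  }
  return s;
}

int main() {
  // (Batch, Num_heads, Seq, Dim): head stride is strides[1], seq stride is strides[2].
  auto bhsd = contiguous_strides({1, 32, 128, 64});
  // (Batch, Seq, Num_heads, Dim) -- the is_seq_at_dim_1 layout: the roles swap,
  // which is exactly what each `if (is_seq_at_dim_1)` block above selects.
  auto bshd = contiguous_strides({1, 128, 32, 64});
  std::printf("BHSD head=%lld seq=%lld\n", (long long)bhsd[1], (long long)bhsd[2]);
  std::printf("BSHD head=%lld seq=%lld\n", (long long)bshd[2], (long long)bshd[1]);
}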
@@ -776,7 +776,6 @@ Tensor& custom_sdpa_out(
    const Tensor& k,
    const Tensor& v,
    const int64_t start_pos,
-    const int64_t seq_len,
    const optional<Tensor>& attn_mask,
    const double dropout_p,
    const bool is_causal,
@@ -792,6 +791,7 @@ Tensor& custom_sdpa_out(

  ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");

+  const int64_t seq_len = q.size(1);
  auto q_seq_len = q.size(1);

  // Refactor the following into create_view util perhaps using
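Because every call site below passes is_seq_at_dim_1 = true, the query is laid out as (Batch, Seq, Num_heads, Dim) and q.size(1) already carries the sequence length, which is why the dropped seq_len parameter was redundant. A small sketch of that invariant, with hypothetical names rather than the ExecuTorch API:

#include <array>
#include <cassert>
#include <cstdint>

// q_sizes models the 4D query shape. With the sequence at dim 1 the layout
// is (Batch, Seq, Num_heads, Dim), so sizes[1] is the sequence length the
// kernel needs; an externally supplied seq_len could only ever agree with it.
int64_t derive_seq_len(const std::array<int64_t, 4>& q_sizes) {
  return q_sizes[1];
}

int main() {
  std::array<int64_t, 4> q_sizes{1, 128, 32, 64}; // B, S, H, D (example values)
  assert(derive_seq_len(q_sizes) == 128);
}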
@@ -870,7 +870,7 @@ Tensor& custom_sdpa_out(
          is_causal,
          attn_mask,
          scale,
-          true,
+          true, /* is_seq_at_dim_1 */
          start_pos);
    } else if (q_seq_len >= 192) {
      cpu_flash_attention<CTYPE, 64, 512>(
@@ -882,7 +882,7 @@ Tensor& custom_sdpa_out(
          is_causal,
          attn_mask,
          scale,
-          true,
+          true, /* is_seq_at_dim_1 */
          start_pos);
    } else {
      cpu_flash_attention<CTYPE, 32, 512>(
@@ -894,7 +894,7 @@ Tensor& custom_sdpa_out(
          is_causal,
          attn_mask,
          scale,
-          true,
+          true, /* is_seq_at_dim_1 */
          start_pos);
    }
  });
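The branches above pick the kernel's compile-time block sizes from the runtime query length. Only the q_seq_len >= 192 branch and the fallback are visible in these hunks, so the sketch below models just those two; the first branch's threshold and template arguments are outside the diff. Hypothetical stand-ins, not the real kernel:

#include <cstdint>
#include <cstdio>

// Stand-in for the templated kernel; qSplitSize/kvSplitSize mirror the
// <CTYPE, 64, 512> and <CTYPE, 32, 512> template arguments above.
template <int64_t qSplitSize, int64_t kvSplitSize>
void kernel_stub(int64_t q_seq_len) {
  std::printf("q_seq_len=%lld -> blocks <%lld, %lld>\n",
              static_cast<long long>(q_seq_len),
              static_cast<long long>(qSplitSize),
              static_cast<long long>(kvSplitSize));
}

void dispatch(int64_t q_seq_len) {
  if (q_seq_len >= 192) {
    kernel_stub<64, 512>(q_seq_len); // larger query blocks for longer sequences
  } else {
    kernel_stub<32, 512>(q_seq_len); // smaller query blocks for short (decode) steps
  }
}

int main() {
  dispatch(256);
  dispatch(1);
}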
@@ -1017,7 +1017,6 @@ Tensor& sdpa_with_kv_cache_out(
      key_cache,
      value_cache,
      start_pos,
-      seq_len,
      attn_mask,
      dropout_p,
      is_causal,
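Net effect at the caller: sdpa_with_kv_cache_out no longer forwards seq_len, since custom_sdpa_out now derives it internally. Sketched from the argument order visible in the hunks only; the full signatures contain parameters the diff does not show:

// Before: custom_sdpa_out(..., key_cache, value_cache, start_pos, seq_len,
//                         attn_mask, dropout_p, is_causal, ...);
// After:  custom_sdpa_out(..., key_cache, value_cache, start_pos,
//                         attn_mask, dropout_p, is_causal, ...);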
0 commit comments