
Commit 29364c4

kimishpatel authored and facebook-github-bot committed
Refactoring sdpa (#5667)
Summary: Pull Request resolved: #5667

- Change variable name from is_sdpa_with_kv_cache to is_seq_at_dim_1 to be more meaningful
- Remove seq_len as it can be derived from the size of query

ghstack-source-id: 245751540
exported-using-ghexport
Reviewed By: metascroy
Differential Revision: D62623243
fbshipit-source-id: 5715f44f9a45d43e9959cfa56ed0023b72fb55e7
1 parent bca3ad6 commit 29364c4
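
For context on the rename, a hedged reading of the flag (a minimal standalone sketch, not the code in op_sdpa.cpp): cpu_flash_attention by default treats the query as (Batch x Num_heads x Q_seq_len x Dim_per_head), i.e. the sequence dimension sits at dim 2, while the custom_sdpa_out path passes tensors with the sequence at dim 1, so the sizes and strides read from dims 1 and 2 trade places. The struct and function below are illustrative names only.

// Sketch only (not the kernel's code): how a seq-at-dim-1 flag picks the
// head count and sequence length out of a 4-D query tensor's sizes.
// Assumed layouts:
//   is_seq_at_dim_1 == false -> query is [batch, num_heads, seq_len, head_dim]
//   is_seq_at_dim_1 == true  -> query is [batch, seq_len, num_heads, head_dim]
#include <array>
#include <cstdint>

struct QueryDims {
  int64_t num_heads;
  int64_t seq_len;
};

QueryDims pick_query_dims(
    const std::array<int64_t, 4>& sizes, bool is_seq_at_dim_1) {
  if (is_seq_at_dim_1) {
    return {sizes[2], sizes[1]}; // heads at dim 2, sequence at dim 1
  }
  return {sizes[1], sizes[2]}; // heads at dim 1, sequence at dim 2
}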

File tree: 1 file changed (+10 −11 lines)

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 10 additions & 11 deletions
@@ -224,7 +224,7 @@ void cpu_flash_attention(
     bool is_causal,
     const optional<Tensor>& attn_mask,
     const optional<double>& scale,
-    bool is_with_kv_cache = false,
+    bool is_seq_at_dim_1 = false,
     const int64_t start_pos = 0) {
   (void)dropout_p;
   // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
@@ -265,7 +265,7 @@ void cpu_flash_attention(
   int64_t kvSize = value.size(2);
   int64_t num_heads_kv = key.size(1);
 
-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     num_head = query.size(2);
     num_heads_kv = key.size(2);
     qSize = query.size(1);
@@ -311,7 +311,7 @@ void cpu_flash_attention(
   int64_t qStrideH = strides[1];
   int64_t qStrideM = strides[2];
 
-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     qStrideH = strides[2];
     qStrideM = strides[1];
   }
@@ -321,7 +321,7 @@ void cpu_flash_attention(
   int64_t kStrideH = strides[1];
   int64_t kStrideN = strides[2];
 
-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     kStrideH = strides[2];
     kStrideN = strides[1];
   }
@@ -331,7 +331,7 @@ void cpu_flash_attention(
   int64_t vStrideH = strides[1];
   int64_t vStrideN = strides[2];
 
-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     vStrideH = strides[2];
     vStrideN = strides[1];
   }
@@ -341,7 +341,7 @@ void cpu_flash_attention(
   int64_t oStrideH = strides[1];
   int64_t oStrideM = strides[2];
 
-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     oStrideH = strides[2];
     oStrideM = strides[1];
   }
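
The four stride hunks above all apply the same adjustment: when the sequence dimension is at dim 1, the head stride and the sequence stride swap. A minimal sketch of that pattern, assuming strides[] holds the batch, dim-1, dim-2, and dim-3 strides of a 4-D tensor (illustrative helper, not part of the file):

// Sketch: pick head/sequence strides from a 4-D tensor's strides, swapping
// them when the layout is [batch, seq, heads, dim] rather than [batch, heads, seq, dim].
#include <cstdint>
#include <utility>

void pick_strides(
    const int64_t strides[4],
    bool is_seq_at_dim_1,
    int64_t& stride_head,
    int64_t& stride_seq) {
  stride_head = strides[1];
  stride_seq = strides[2];
  if (is_seq_at_dim_1) {
    std::swap(stride_head, stride_seq); // same swap for q, k, v and the output
  }
}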
@@ -776,7 +776,6 @@ Tensor& custom_sdpa_out(
     const Tensor& k,
     const Tensor& v,
     const int64_t start_pos,
-    const int64_t seq_len,
     const optional<Tensor>& attn_mask,
     const double dropout_p,
     const bool is_causal,
@@ -792,6 +791,7 @@ Tensor& custom_sdpa_out(
 
   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
 
+  const int64_t seq_len = q.size(1);
   auto q_seq_len = q.size(1);
 
   // Refactor the following into create_view util perhaps using
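
With this hunk, seq_len is no longer an argument to custom_sdpa_out; it is read off the query itself. Since this path lays the query out with the sequence at dim 1, the derivation is simply q.size(1). A standalone sketch of the same idea (hypothetical helper name, not the op's API):

// Sketch: recover the sequence length from a 4-D query laid out as
// [batch, seq_len, num_heads, head_dim] instead of passing it separately.
#include <array>
#include <cstdint>

int64_t derive_seq_len(const std::array<int64_t, 4>& query_sizes) {
  return query_sizes[1]; // sequence dimension sits at index 1 in this layout
}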
@@ -870,7 +870,7 @@ Tensor& custom_sdpa_out(
         is_causal,
         attn_mask,
         scale,
-        true,
+        true, /* is_seq_at_dim_1 */
         start_pos);
   } else if (q_seq_len >= 192) {
     cpu_flash_attention<CTYPE, 64, 512>(
@@ -882,7 +882,7 @@ Tensor& custom_sdpa_out(
         is_causal,
         attn_mask,
         scale,
-        true,
+        true, /* is_seq_at_dim_1 */
         start_pos);
   } else {
     cpu_flash_attention<CTYPE, 32, 512>(
@@ -894,7 +894,7 @@ Tensor& custom_sdpa_out(
         is_causal,
         attn_mask,
         scale,
-        true,
+        true, /* is_seq_at_dim_1 */
         start_pos);
   }
 });
@@ -1017,7 +1017,6 @@ Tensor& sdpa_with_kv_cache_out(
       key_cache,
       value_cache,
       start_pos,
-      seq_len,
       attn_mask,
       dropout_p,
       is_causal,
