Commit 5e23c33

JacobSzwejbka authored and facebook-github-bot committed
off by one error in sdpa cache op (#2689)
Summary:
Pull Request resolved: #2689

start_pos is where we currently are, so when start_pos = 127 we are generating the 128th token.

bypass-github-export-checks

Reviewed By: kimishpatel

Differential Revision: D55370500

fbshipit-source-id: 412b1638cc320a4df56dcc38995f3836ed5c425b
1 parent 9baa2df commit 5e23c33

1 file changed: +3 -3 lines changed
examples/models/llama2/custom_ops/op_sdpa.cpp

@@ -558,7 +558,7 @@ bool validate_cache_params(
       "start_pos must be less than value cache at dim 1");
 
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      (start_pos + seq_length) < k_cache.size(1),
+      (start_pos + seq_length) <= k_cache.size(1),
       "start_post + seq_length must be less than max seq length supported by key cache."
       "start pos: %" PRId64 ", seq_length: %" PRId64
       "."
@@ -568,14 +568,14 @@ bool validate_cache_params(
       k_cache.size(1));
 
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      (start_pos + seq_length) < v_cache.size(1),
+      (start_pos + seq_length) <= v_cache.size(1),
       "start_post + seq_length must be less than max seq length supported by key cache."
       "start pos: %" PRId64 ", seq_length: %" PRId64
       "."
       "value cache size: %zd",
       start_pos,
       seq_length,
-      v_cache.size(2));
+      v_cache.size(1));
 
   // Make sure they are in contiguous dim order
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
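
The boundary case from the summary can be checked in isolation. The sketch below is not ExecuTorch code: fits_old and fits_new are hypothetical stand-ins for the comparison before and after this commit, and cache_len is an assumed name for k_cache.size(1) or v_cache.size(1). With a cache of 128 positions, start_pos = 127 and seq_length = 1 write only index 127, which is the last valid slot, so the check should pass.

```cpp
#include <cstdint>

// Hypothetical stand-in for the check before this commit.
// Writing seq_length entries at start_pos touches indices
// [start_pos, start_pos + seq_length), so the strict '<' wrongly
// rejects a write that ends exactly at the last cache slot.
constexpr bool fits_old(int64_t start_pos, int64_t seq_length, int64_t cache_len) {
  return (start_pos + seq_length) < cache_len;
}

// Hypothetical stand-in for the check after this commit.
// The exclusive end of the write may equal cache_len.
constexpr bool fits_new(int64_t start_pos, int64_t seq_length, int64_t cache_len) {
  return (start_pos + seq_length) <= cache_len;
}

// Generating the 128th token (start_pos = 127) in a 128-entry cache.
static_assert(!fits_old(127, 1, 128), "old check rejects the last slot");
static_assert(fits_new(127, 1, 128), "new check accepts the last slot");

// One position past the end is still rejected.
static_assert(!fits_new(128, 1, 128), "write past the end is rejected");
```

The static_assert lines run at compile time, so building the file is enough to confirm the three cases.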
