Commit e7109f4

Update on "[Executorch][llama] Allow custom sdpa op replacement pass to leverage attention mask"
Previously we assumed that the custom sdpa always does causal attention. This diff adds an option to the module swap pass so that the custom sdpa can leverage an explicit attention mask instead of causal masking.

Differential Revision: [D73222736](https://our.internmc.facebook.com/intern/diff/D73222736/)

[ghstack-poisoned]
2 parents 013874c + 23e0599 commit e7109f4
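
To make the change concrete, here is a minimal sketch of calling the custom SDPA op with an explicit attention mask rather than relying on causal masking. It assumes the operator signature exercised by the tests in this diff (positional `attn_mask`, `dropout_p`, `is_causal`, `scale` arguments) and that the ExecuTorch custom-op library has already been loaded so `torch.ops.llama.sdpa_with_kv_cache` is registered; shapes and values are illustrative only.

```python
import torch

# Illustrative shapes only; the deleted C++ test below uses the same layout:
# [batch, seq_len, n_heads, head_dim] for q/k/v and
# [batch, max_seq_len, n_heads, head_dim] for the caches.
bsz, seq_len, n_heads, head_dim = 1, 1, 4, 4
start_pos, max_seq_len = 0, 5

q = torch.rand(bsz, seq_len, n_heads, head_dim)
k = torch.rand(bsz, seq_len, n_heads, head_dim)
v = torch.rand(bsz, seq_len, n_heads, head_dim)
k_cache = torch.zeros(bsz, max_seq_len, n_heads, head_dim)
v_cache = torch.zeros(bsz, max_seq_len, n_heads, head_dim)

# Additive mask over [seq_len, start_pos + seq_len]: 0.0 means "attend",
# -inf means "blocked". All zeros lets the query see every cached position.
attn_mask = torch.zeros(seq_len, start_pos + seq_len).contiguous()

out = torch.ops.llama.sdpa_with_kv_cache(
    q, k, v,
    k_cache, v_cache,
    start_pos, seq_len,
    attn_mask,  # explicit mask drives attention...
    0.0,        # dropout_p
    False,      # ...because is_causal is False
    None,       # scale: fall back to the default 1/sqrt(head_dim)
)
```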

File tree

2 files changed: +16 additions, -291 deletions


extension/llm/custom_ops/op_sdpa_with_kv_cache_test.cpp

Lines changed: 0 additions & 283 deletions
```diff
@@ -524,289 +524,6 @@ TEST(OpScaledDotProductAttentionTest, LargerTest) {
   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_5, 1e-4, 1e-4);
 }
 
-TEST(OpScaledDotProductAttentionTest, BasicTestWithAttnMask) {
-  TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
-
-  executorch::aten::Tensor query = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8823,
-       0.9150,
-       0.3829,
-       0.9593,
-       0.3904,
-       0.6009,
-       0.2566,
-       0.7936,
-       0.9408,
-       0.1332,
-       0.9346,
-       0.5936,
-       0.8694,
-       0.5677,
-       0.7411,
-       0.4294});
-  executorch::aten::Tensor key = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8854,
-       0.5739,
-       0.2666,
-       0.6274,
-       0.2696,
-       0.4414,
-       0.2969,
-       0.8317,
-       0.1053,
-       0.2695,
-       0.3588,
-       0.1994,
-       0.5472,
-       0.0062,
-       0.9516,
-       0.0753});
-  executorch::aten::Tensor value = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8860,
-       0.5832,
-       0.3376,
-       0.8090,
-       0.5779,
-       0.9040,
-       0.5547,
-       0.3423,
-       0.6343,
-       0.3644,
-       0.7104,
-       0.9464,
-       0.7890,
-       0.2814,
-       0.7886,
-       0.5895});
-  executorch::aten::Tensor attn_mask = tfFloat.make({1, 1}, {0});
-  executorch::aten::Tensor key_cache_0 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor value_cache_0 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor key_cache_1 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor value_cache_1 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor key_cache_2 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor value_cache_2 = tfFloat.zeros({1, 5, 4, 4});
-  double dropout_p = 0;
-  bool is_causal = false;
-  executorch::aten::optional<double> scale;
-
-  // start pos: 0 layer id 0
-  executorch::aten::Tensor ret_expected_0 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8860,
-       0.5832,
-       0.3376,
-       0.8090,
-       0.5779,
-       0.9040,
-       0.5547,
-       0.3423,
-       0.6343,
-       0.3644,
-       0.7104,
-       0.9464,
-       0.7890,
-       0.2814,
-       0.7886,
-       0.5895});
-
-  std::vector<int32_t> out_size = {1, 1, 4, 4};
-  executorch::aten::Tensor out = tfFloat.zeros(out_size);
-  executorch::aten::Tensor ret = op_sdpa_with_kv_cache(
-      query,
-      key,
-      value,
-      key_cache_0,
-      value_cache_0,
-      0,
-      1,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_0, 1e-4, 1e-4);
-
-  // start pos: 0 layer id 2
-  executorch::aten::Tensor ret_expected_1 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8860,
-       0.5832,
-       0.3376,
-       0.8090,
-       0.5779,
-       0.9040,
-       0.5547,
-       0.3423,
-       0.6343,
-       0.3644,
-       0.7104,
-       0.9464,
-       0.7890,
-       0.2814,
-       0.7886,
-       0.5895});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query,
-      key,
-      value,
-      key_cache_2,
-      value_cache_2,
-      0,
-      1,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_1, 1e-4, 1e-4);
-
-  attn_mask = tfFloat.make({1, 2}, {0, 0});
-  // start pos: 1 layer id 0
-  executorch::aten::Tensor ret_expected_2 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8860,
-       0.5832,
-       0.3376,
-       0.8090,
-       0.5779,
-       0.9040,
-       0.5547,
-       0.3423,
-       0.6343,
-       0.3644,
-       0.7104,
-       0.9464,
-       0.7890,
-       0.2814,
-       0.7886,
-       0.5895});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query,
-      key,
-      value,
-      key_cache_0,
-      value_cache_0,
-      1,
-      1,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_2, 1e-4, 1e-4);
-
-  // start pos: 1 layer id 1
-  executorch::aten::Tensor ret_expected_3 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.6486,
-       0.4270,
-       0.2472,
-       0.5922,
-       0.3669,
-       0.5740,
-       0.3522,
-       0.2173,
-       0.3635,
-       0.2088,
-       0.4071,
-       0.5423,
-       0.5110,
-       0.1822,
-       0.5107,
-       0.3817});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query,
-      key,
-      value,
-      key_cache_1,
-      value_cache_1,
-      1,
-      1,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_3, 1e-4, 1e-4);
-
-  attn_mask = tfFloat.make({1, 3}, {0, 0, 0});
-  // start pos: 2 layer id 1
-  executorch::aten::Tensor ret_expected_4 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.7490,
-       0.4930,
-       0.2854,
-       0.6838,
-       0.4489,
-       0.7021,
-       0.4308,
-       0.2659,
-       0.4622,
-       0.2655,
-       0.5176,
-       0.6895,
-       0.6202,
-       0.2212,
-       0.6199,
-       0.4634});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query,
-      key,
-      value,
-      key_cache_1,
-      value_cache_1,
-      2,
-      1,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_4, 1e-4, 1e-4);
-
-  // start pos: 2 layer id 2
-  executorch::aten::Tensor ret_expected_5 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.7490,
-       0.4930,
-       0.2854,
-       0.6838,
-       0.4489,
-       0.7021,
-       0.4308,
-       0.2659,
-       0.4622,
-       0.2655,
-       0.5176,
-       0.6895,
-       0.6202,
-       0.2212,
-       0.6199,
-       0.4634});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query,
-      key,
-      value,
-      key_cache_2,
-      value_cache_2,
-      2,
-      1,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_5, 1e-4, 1e-4);
-}
-
 TEST(OpScaledDotProductAttentionTest, SequenceTest) {
   TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
 
```
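
For reference on the deleted test's masks: each `attn_mask` there is an all-zero additive mask of shape `{seq_len, start_pos + seq_len}` (`{1, 1}`, `{1, 2}`, `{1, 3}` as `start_pos` grows), which lets the single new query position attend to every cached position. With `seq_len == 1` that coincides with causal attention, which is why the same mask-driven scenarios can be exercised from the Python test below instead. A small sketch of that relationship (pure PyTorch, no custom op; the helper name is made up for illustration):

```python
import torch

def additive_mask(seq_len: int, start_pos: int, causal: bool) -> torch.Tensor:
    """Additive attention mask of shape [seq_len, start_pos + seq_len]:
    0.0 where attention is allowed, -inf where it is blocked."""
    mask = torch.zeros(seq_len, start_pos + seq_len)
    if causal:
        # Query i sits at absolute position start_pos + i and may only
        # attend to key positions <= start_pos + i.
        for i in range(seq_len):
            mask[i, start_pos + i + 1 :] = float("-inf")
    return mask

# For single-token decode (seq_len == 1) the causal mask has nothing to block,
# so it equals the all-zero mask used by the deleted test.
assert torch.equal(additive_mask(1, 2, causal=True), additive_mask(1, 2, causal=False))
```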

extension/llm/custom_ops/test_sdpa_with_kv_cache.py

Lines changed: 16 additions & 8 deletions
```diff
@@ -67,12 +67,14 @@ def test_sdpa_with_cache_no_mqa_1(self):
         )
         if self.use_mask_with_custom_op:
             attn_mask = attn_mask.contiguous()
+            sliced_k_cache = self.k_cache[:, : start_pos + seq_len, :, :]
+            sliced_v_cache = self.v_cache[:, : start_pos + seq_len, :, :]
             op_output = torch.ops.llama.sdpa_with_kv_cache(
                 q,
                 k,
                 v,
-                self.k_cache,
-                self.v_cache,
+                sliced_k_cache,
+                sliced_v_cache,
                 start_pos,
                 seq_len,
                 attn_mask,
@@ -108,12 +110,14 @@ def test_sdpa_with_cache_no_mqa_2(self):
         )
         if self.use_mask_with_custom_op:
             attn_mask = attn_mask.contiguous()
+            sliced_k_cache = self.k_cache[:, : start_pos + seq_len, :, :]
+            sliced_v_cache = self.v_cache[:, : start_pos + seq_len, :, :]
             op_output = torch.ops.llama.sdpa_with_kv_cache(
                 q,
                 k,
                 v,
-                self.k_cache,
-                self.v_cache,
+                sliced_k_cache,
+                sliced_v_cache,
                 start_pos,
                 seq_len,
                 attn_mask,
@@ -150,12 +154,14 @@ def test_sdpa_with_cache_no_mqa_3(self):
         )
         if self.use_mask_with_custom_op:
             attn_mask = attn_mask.contiguous()
+            sliced_k_cache = self.k_cache[:, : start_pos + seq_len, :, :]
+            sliced_v_cache = self.v_cache[:, : start_pos + seq_len, :, :]
             op_output = torch.ops.llama.sdpa_with_kv_cache(
                 q,
                 k,
                 v,
-                self.k_cache,
-                self.v_cache,
+                sliced_k_cache,
+                sliced_v_cache,
                 start_pos,
                 seq_len,
                 attn_mask,
@@ -191,12 +197,14 @@ def test_sdpa_with_cache_no_mqa_4(self):
         )
         if self.use_mask_with_custom_op:
             attn_mask = attn_mask.contiguous()
+            sliced_k_cache = self.k_cache[:, : start_pos + seq_len, :, :]
+            sliced_v_cache = self.v_cache[:, : start_pos + seq_len, :, :]
             op_output = torch.ops.llama.sdpa_with_kv_cache(
                 q,
                 k,
                 v,
-                self.k_cache,
-                self.v_cache,
+                sliced_k_cache,
+                sliced_v_cache,
                 start_pos,
                 seq_len,
                 attn_mask,
```
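
One detail worth noting about the slicing above: a basic slice like `self.k_cache[:, : start_pos + seq_len, :, :]` is a view, not a copy, so when the op writes the new key/value entries into the sliced cache in place, the update remains visible through the full `self.k_cache` / `self.v_cache` buffers. A small self-contained check of that assumption (toy shapes, no custom op involved):

```python
import torch

k_cache = torch.zeros(1, 5, 4, 4)
start_pos, seq_len = 1, 1

# Same slicing pattern as the test: keep the already-valid prefix of the
# cache plus room for the new entry.
sliced = k_cache[:, : start_pos + seq_len, :, :]

# Simulate the op's in-place cache write at position start_pos.
sliced[:, start_pos] = torch.ones(1, 4, 4)

assert sliced.data_ptr() == k_cache.data_ptr()               # shares storage with the full cache
assert torch.equal(k_cache[0, start_pos], torch.ones(4, 4))  # write is visible in the full cache
```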
