Commit 13d2fa7

Final final fix
1 parent 1c3b5ae commit 13d2fa7

File tree

examples/models/llama2/custom_ops/CMakeLists.txt
examples/models/llama2/custom_ops/op_sdpa_aot.cpp
examples/models/llama2/custom_ops/sdpa_with_kv_cache.py

3 files changed: +113 -4 lines

examples/models/llama2/custom_ops/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ include(${EXECUTORCH_SRCS_FILE})
 set(_common_include_directories ${EXECUTORCH_ROOT}/..)
 
 # Custom op libraries
-set(custom_ops_libs extension_module executorch)
+set(custom_ops_libs executorch)
 list(APPEND custom_ops_libs pthreadpool)
 list(APPEND custom_ops_libs cpuinfo)
 list(APPEND custom_ops_libs cpublas)
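The shared library this target builds is what the Python wrapper below loads with torch.ops.load_library before any llama:: op becomes callable. A minimal sketch of that loading step, assuming a local CMake build; the output path and library file name here are illustrative, not taken from this commit:

import torch

# Load the AOT custom-op shared library so torch.ops.llama.* is registered.
# The exact path depends on your build directory layout (hypothetical here).
torch.ops.load_library(
    "cmake-out/examples/models/llama2/custom_ops/libcustom_ops_aot_lib.so"
)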

examples/models/llama2/custom_ops/op_sdpa_aot.cpp

Lines changed: 9 additions & 3 deletions
@@ -85,12 +85,18 @@ TORCH_LIBRARY(llama, m) {
   m.def(
       "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
       "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
-      "float drpout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor",
-      &torch::executor::native::sdpa_with_kv_cache_aten);
+      "float drpout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor");
   m.def(
       "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
       "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
-      "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)",
+      "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
+}
+
+TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
+  m.impl(
+      "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
+  m.impl(
+      "sdpa_with_kv_cache.out",
       WRAP_TO_ATEN(
           torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
 }
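This change splits schema definition (m.def inside TORCH_LIBRARY) from kernel registration (m.impl inside TORCH_LIBRARY_IMPL under the CompositeExplicitAutograd key). Leaving the schema unbound to a single kernel is what lets the Python file below attach a Meta kernel to the same op. A minimal sketch of the same def/impl split using Python's torch.library API; the myops namespace and the double op are hypothetical, for illustration only:

import torch

# Define the schema only (mirrors TORCH_LIBRARY + m.def in the diff above).
def_lib = torch.library.Library("myops", "DEF")
def_lib.define("double(Tensor x) -> Tensor")

# Register a kernel separately (mirrors TORCH_LIBRARY_IMPL + m.impl).
impl_lib = torch.library.Library("myops", "IMPL")

def double_impl(x):
    # Runs for any backend covered by this dispatch key.
    return x * 2

impl_lib.impl("double", double_impl, "CompositeExplicitAutograd")

print(torch.ops.myops.double(torch.ones(3)))  # tensor([2., 2., 2.])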

examples/models/llama2/custom_ops/sdpa_with_kv_cache.py

Lines changed: 103 additions & 0 deletions
@@ -13,6 +13,8 @@
 
 import torch
 
+from torch.library import impl
+
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
@@ -30,3 +32,104 @@
     torch.ops.load_library(full_path)
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
+
+custom_ops_lib = torch.library.Library("llama", "IMPL")
+
+
+def _validate_params(
+    query,
+    key,
+    value,
+    key_cache,
+    value_cache,
+    start_pos,
+    seq_len,
+    attn_mask,
+    drpout_p,
+    is_causal,
+    scale,
+):
+    assert (
+        query.dim() == 4
+    ), f"Expected query to be 4 dimensional but got {query.dim()} dimensions."
+    assert (
+        key.dim() == 4
+    ), f"Expected key to be 4 dimensional but got {key.dim()} dimensions."
+    assert (
+        value.dim() == 4
+    ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions."
+
+    assert (
+        query.dtype == torch.float32
+    ), f"Expected query to be float32 but got {query.dtype}"
+    assert key.dtype == torch.float32, f"Expected key to be float32 but got {key.dtype}"
+    assert (
+        value.dtype == torch.float32
+    ), f"Expected value to be float32 but got {value.dtype}"
+
+    assert (
+        key_cache.dim() == 4
+    ), f"Expected key_cache to be 4 dimensional but got {key_cache.dim()}"
+    assert (
+        value_cache.dim() == 4
+    ), f"Expected value_cache to be 4 dimensional but got {value_cache.dim()}"
+
+    assert (
+        key_cache.dtype == torch.float32
+    ), f"Expected key_cache to be float32 but got {key_cache.dtype}"
+    assert (
+        value_cache.dtype == torch.float32
+    ), f"Expected value_cache to be float32 but got {value_cache.dtype}"
+
+    assert (
+        key_cache.size() == value_cache.size()
+    ), f"Key cache and value cache must have same size but got {key_cache.size()} and {value_cache.size()}"
+
+    # These asserts are valid, but enforcing them would require adding
+    # constrain_as_size/constrain_as_value calls to the model, which is out of
+    # scope for now.
+    # assert start_pos < key_cache.size(
+    #     1
+    # ), f"Start position {start_pos} must be less than sequence length {key_cache.size(2)}"
+    # assert (start_pos + seq_len) < key_cache.size(
+    #     1
+    # ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {key_cache.size(2)}"
+
+    assert seq_len == 1, "Only support seq_len = 1 for now."
+
+    if attn_mask is not None:
+        assert (
+            attn_mask.dim() == 2
+        ), f"Expected attn_mask to be 2 dimensional but got {attn_mask.dim()} dimensions."
+        assert (attn_mask.dtype == torch.float32) or (
+            attn_mask.dtype == torch.float16
+        ), f"Expected attn_mask to be float but got {attn_mask.dtype}"
+
+
+@impl(custom_ops_lib, "sdpa_with_kv_cache", "Meta")
+def sdpa_with_kv_cache_meta(
+    query,
+    key,
+    value,
+    key_cache,
+    value_cache,
+    start_pos,
+    seq_len,
+    attn_mask=None,
+    drpout_p=0.0,
+    is_causal=False,
+    scale=None,
+):
+    _validate_params(
+        query,
+        key,
+        value,
+        key_cache,
+        value_cache,
+        start_pos,
+        seq_len,
+        attn_mask,
+        drpout_p,
+        is_causal,
+        scale,
+    )
+
+    return torch.empty_like(query)
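With the Meta kernel registered, the op can be shape-checked and traced without running any real computation: tensors on the "meta" device dispatch to sdpa_with_kv_cache_meta above. A minimal sketch of exercising it, assuming this module has been imported so the op schema and kernel are loaded; the tensor sizes and dimension layout are illustrative only, not taken from the commit:

import torch

q = torch.empty(1, 1, 8, 64, device="meta")  # 4-D float32, as the checks require
k = torch.empty(1, 1, 8, 64, device="meta")
v = torch.empty(1, 1, 8, 64, device="meta")
k_cache = torch.empty(1, 128, 8, 64, device="meta")
v_cache = torch.empty(1, 128, 8, 64, device="meta")

# start_pos=0, seq_len=1 (seq_len must be 1 per the validation above).
out = torch.ops.llama.sdpa_with_kv_cache(q, k, v, k_cache, v_cache, 0, 1)
print(out.shape)  # torch.Size([1, 1, 8, 64]): empty_like(query) on the meta device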
