
Commit 5f324ce

kimishpatel authored and facebook-github-bot committed
Add update_quantized_cache op (#5527)
Summary:
Pull Request resolved: #5527

Why?
- Functionalization introduces a ton of copies.
- Mutable buffer support without such custom in-place ops results in giant copies at the end.
- Making true in-place ops work will likely take longer, and there is no clear safe path to it.

ghstack-source-id: 245703481
exported-using-ghexport

Reviewed By: metascroy

Differential Revision: D62301838

fbshipit-source-id: ad9d0cf14f50bb369409976acdd42f860a10b1a7
1 parent d459011 commit 5f324ce
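To make the motivation concrete, here is a minimal sketch, with illustrative [batch, seq, heads, dim] shapes, of the full-cache copy a functionalized update incurs versus the in-place write this custom op performs (functional_update and inplace_update are hypothetical helpers, not from this commit):

import torch

def functional_update(value, cache, start_pos):
    # What a functionalized graph effectively does: clone the entire cache,
    # write the new slice into the clone, and copy the clone back into the
    # mutable buffer afterwards.
    new_cache = cache.clone()
    new_cache[:, start_pos : start_pos + value.size(1)] = value
    return new_cache

def inplace_update(value, cache, start_pos):
    # What the custom op provides: write the new tokens directly into the
    # cache buffer, with no full-cache copy.
    cache[:, start_pos : start_pos + value.size(1)] = value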

File tree

7 files changed: +459 −2 lines changed

extension/llm/custom_ops/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -22,6 +22,19 @@ runtime.python_test(
     ],
 )
 
+runtime.python_test(
+    name = "test_update_quantized_cache",
+    srcs = [
+        "test_update_quantized_cache.py",
+    ],
+    preload_deps = [
+        ":custom_ops_aot_lib",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 runtime.python_test(
     name = "test_preprocess_custom_ops",
     srcs = [
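The preload_deps entry loads the custom-ops shared library before the test imports run, so the test can call the registered op directly. The real test_update_quantized_cache.py is not shown on this page; the sketch below is a hypothetical example of the kind of check such a test might perform:

import unittest

import torch

class UpdateQuantizedCacheSketch(unittest.TestCase):
    # Hypothetical test body (assumes custom_ops_aot_lib has been preloaded
    # so that torch.ops.llama.update_quantized_cache is registered).
    def test_update_writes_slice_in_place(self):
        value = torch.ones(1, 2, 4, 8, dtype=torch.int8)
        cache = torch.zeros(1, 16, 4, 8, dtype=torch.int8)
        torch.ops.llama.update_quantized_cache(value, cache, 3)
        self.assertTrue(cache[:, 3:5].eq(1).all())
        self.assertTrue(cache[:, :3].eq(0).all())

if __name__ == "__main__":
    unittest.main()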

extension/llm/custom_ops/op_sdpa_aot.cpp

Lines changed: 39 additions & 2 deletions
@@ -9,14 +9,14 @@
 #include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
 #include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
 #include <executorch/extension/llm/custom_ops/op_sdpa.h>
+#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>
 
 #include <torch/library.h>
 
 namespace torch {
 namespace executor {
 
 namespace native {
-namespace {
 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
     const Tensor& k_projected,
@@ -81,7 +81,27 @@ at::Tensor sdpa_with_kv_cache_aten(
       output);
   return output;
 }
-} // namespace
+
+Tensor& update_quantized_cache_out_no_context(
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output) {
+  exec_aten::RuntimeContext context{};
+  return torch::executor::native::update_quantized_cache_out(
+      context, value, cache, start_pos, output);
+}
+
+at::Tensor update_quantized_cache_aten(
+    const at::Tensor& value,
+    at::Tensor& cache,
+    const int64_t start_pos) {
+  auto output = at::empty({1});
+  WRAP_TO_ATEN(update_quantized_cache_out_no_context, 3)
+  (value, cache, start_pos, output);
+  return output;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
@@ -95,6 +115,12 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
       "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
       "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
+  m.def(
+      "update_quantized_cache(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos) -> Tensor");
+  m.def(
+      "update_quantized_cache.out(Tensor value, Tensor(a!) cache, "
+      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
 }
 
 TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
@@ -105,3 +131,14 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
       WRAP_TO_ATEN(
           torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
 }
+
+// TODO: Rename this file to op_custom_ops_aot.cpp
+TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
+  m.impl(
+      "update_quantized_cache",
+      torch::executor::native::update_quantized_cache_aten);
+  m.impl(
+      "update_quantized_cache.out",
+      WRAP_TO_ATEN(
+          torch::executor::native::update_quantized_cache_out_no_context, 3));
+}
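Once the AOT library built from this file is loaded, the schemas defined above are callable from Python as torch.ops.llama.update_quantized_cache. A hedged usage sketch; the shared-library path and tensor shapes are assumptions for illustration:

import torch

# Loading the library registers the llama::update_quantized_cache schemas;
# the actual .so path depends on how custom_ops_aot_lib was built (assumed here).
torch.ops.load_library("custom_ops_aot_lib.so")

value = torch.randint(0, 127, (1, 4, 8, 64), dtype=torch.int8)  # [B, seq, heads, dim]
cache = torch.zeros(1, 128, 8, 64, dtype=torch.int8)            # [B, max_seq, heads, dim]

# Mutates `cache` in place; the returned tensor is only a placeholder.
torch.ops.llama.update_quantized_cache(value, cache, 0)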
extension/llm/custom_ops/op_update_quantized_cache.cpp

Lines changed: 143 additions & 0 deletions

@@ -0,0 +1,143 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>
+
+#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
+// @lint-ignore CLANGTIDY facebook-unused-include-check
+#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
+
+#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+namespace {
+bool validate_cache_params(
+    const Tensor& quantized_value,
+    const Tensor& quantized_cache,
+    int64_t start_pos,
+    int64_t seq_length) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      quantized_cache.dim() == 4, "quantized cache must be a 4D tensor");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      quantized_value.dim() == 4, "quantized_value must be a 4D tensor");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      start_pos < quantized_cache.size(1),
+      "start_pos must be less than cache size at dim 1");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      (start_pos + seq_length) <= quantized_cache.size(1),
+      "start_pos + seq_length must be less than max seq length supported by cache. "
+      "start pos: %" PRId64 ", seq_length: %" PRId64
+      ". "
+      "cache size: %zd",
+      start_pos,
+      seq_length,
+      quantized_cache.size(1));
+
+  // Make sure they are in contiguous dim order
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      is_contiguous_dim_order(
+          quantized_cache.dim_order().data(), quantized_cache.dim()),
+      "quantized cache must be in contiguous dim order");
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      is_contiguous_dim_order(
+          quantized_value.dim_order().data(), quantized_value.dim()),
+      "quantized value must be in contiguous dim order");
+
+  return true;
+}
+} // anonymous namespace
+
+Tensor& update_quantized_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output) {
+  (void)ctx;
+  int64_t seq_len = value.size(1);
+  ET_KERNEL_CHECK(
+      ctx,
+      validate_cache_params(value, cache, start_pos, seq_len),
+      InvalidArgument,
+      output);
+
+  ET_CHECK_MSG(
+      value.size(0) == cache.size(0),
+      "projected_value batch size should be equal to the cache batch size.");
+  ET_CHECK_MSG(
+      value.size(2) == cache.size(2),
+      "projected_value number of heads should be equal to the cache number of heads.");
+  ET_CHECK_MSG(
+      value.size(3) == cache.size(3),
+      "projected_value embedding dimension should be equal to the cache embedding dimension.");
+  ET_CHECK_MSG(
+      value.element_size() == cache.element_size(),
+      "projected_value data type size should be equal to the cache data type size.");
+
+  ET_CHECK_MSG(
+      is_contiguous_dim_order(value.dim_order().data(), value.dim()),
+      "projected value must be in contiguous dim order");
+  ET_CHECK_MSG(
+      is_contiguous_dim_order(cache.dim_order().data(), cache.dim()),
+      "cache must be in contiguous dim order");
+
+  const void* value_data = value.const_data_ptr();
+  void* cache_data = cache.mutable_data_ptr();
+
+  ET_CHECK_MSG(value_data, "projected_value data is null");
+  ET_CHECK_MSG(cache_data, "cache data is null");
+
+  auto cache_strides = cache.strides();
+  exec_aten::StridesType cache_batch_dim_stride = cache_strides[0];
+  exec_aten::StridesType cache_seq_dim_stride = cache_strides[1];
+
+  auto value_strides = value.strides();
+  exec_aten::StridesType value_batch_dim_stride = value_strides[0];
+
+  exec_aten::SizesType num_bytes_to_copy =
+      (value.numel() / value.size(0)) * value.element_size();
+
+  for (int64_t batch_line = 0; batch_line < value.size(0); ++batch_line) {
+    exec_aten::SizesType cache_pos_offset =
+        (batch_line * cache_batch_dim_stride +
+         start_pos * cache_seq_dim_stride) *
+        cache.element_size();
+    exec_aten::SizesType value_pos_offset =
+        (batch_line * value_batch_dim_stride) * cache.element_size();
+
+    std::memcpy(
+        (uint8_t*)cache_data + cache_pos_offset,
+        (uint8_t*)value_data + value_pos_offset,
+        num_bytes_to_copy);
+  }
+
+  // No one uses output. Just a placeholder.
+  return output;
+}
+} // namespace native
+} // namespace executor
+} // namespace torch
+
+// Really this is just an in-place tensor update op
+// that makes assumptions about the rank of the tensor
+// and its dim order (memory layout), and further assumes
+// that the indexing is along the sequence dimension
+// (dim 1) of the tensor.
+// In later diffs this will be renamed to update_cache.
+EXECUTORCH_LIBRARY(
+    llama,
+    "update_quantized_cache.out",
+    torch::executor::native::update_quantized_cache_out);
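For contiguous [batch, seq, heads, dim] tensors, the per-batch memcpy loop above is equivalent to a single slice assignment along the sequence dimension. A reference sketch in eager PyTorch, with shapes assumed for illustration (update_cache_reference is a hypothetical helper, not part of this commit):

import torch

def update_cache_reference(value: torch.Tensor, cache: torch.Tensor, start_pos: int) -> None:
    # Same effect as the kernel: copy seq_len * heads * dim elements per
    # batch line into the cache, starting at start_pos along dim 1.
    seq_len = value.size(1)
    cache[:, start_pos : start_pos + seq_len] = value

cache = torch.zeros(2, 16, 4, 8, dtype=torch.int8)
value = torch.ones(2, 3, 4, 8, dtype=torch.int8)
update_cache_reference(value, cache, start_pos=5)
assert cache[:, 5:8].eq(1).all()
assert cache[:, :5].eq(0).all()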
extension/llm/custom_ops/op_update_quantized_cache.h

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+Tensor& update_quantized_cache_out(
+    RuntimeContext& ctx,
+    const Tensor& value,
+    Tensor& cache,
+    const int64_t start_pos,
+    Tensor& output);
+} // namespace native
+} // namespace executor
+} // namespace torch

extension/llm/custom_ops/sdpa_with_kv_cache.py

Lines changed: 52 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 from torch.library import impl
 
+# TODO: rename this file to custom_ops_meta_registration.py
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
@@ -138,3 +139,54 @@ def fast_hadamard_transform_meta(mat):
     # assert(mat.shape[-1] == 128 or mat.shape[-1] == 14336, "unexpected input size for llama3 demo!")
     # assert(mat.is_contiguous(), "input matrix must be contiguous currently!")
     return torch.empty_like(mat)
+
+
+def _validate_update_cache_params(
+    value,
+    cache,
+    start_pos,
+):
+    seq_len = value.size(1)
+    assert (
+        value.dim() == 4
+    ), f"Expected value to be 4 dimensional but got {value.dim()} dimensions."
+
+    assert (
+        value.dtype == cache.dtype
+    ), f"Expected value and cache to be of the same type but got value type {value.dtype} and cache type {cache.dtype}"
+
+    for i in [0, 2, 3]:
+        assert value.size(i) == cache.size(
+            i
+        ), f"Expected value and cache to have same size in dimension {i} but got {value.size(i)} and {cache.size(i)}"
+
+    torch._check_is_size(start_pos)
+    # Setting to arbitrary limit of 256 for now since there is no way
+    # to plumb this information from model config
+    torch._check(start_pos < cache.size(1))
+    assert start_pos < cache.size(
+        1
+    ), f"Start position {start_pos} must be less than sequence length {cache.size(1)}"
+
+    torch._check((start_pos + seq_len) < cache.size(1))
+    assert (start_pos + seq_len) < cache.size(
+        1
+    ), f"Start position + length = {start_pos + seq_len} must be less than sequence length {cache.size(1)}"
+
+
+@impl(custom_ops_lib, "update_quantized_cache", "Meta")
+def update_quantized_cache_meta(
+    value,
+    cache,
+    start_pos,
+):
+    _validate_update_cache_params(
+        value,
+        cache,
+        start_pos,
+    )
+
+    # Update cache doesn't really return anything, but I don't know a better
+    # workaround. Should we just return cache instead? But I am afraid that
+    # will result in extra memory allocation.
+    return torch.empty((1,), dtype=value.dtype, device="meta")
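With the meta kernel registered, the op can be traced with meta (shape-only) tensors during export without executing the real kernel. A minimal sketch, assuming the AOT library above has already been loaded and using illustrative shapes:

import torch

# Dispatches to update_quantized_cache_meta: shapes are validated and a
# placeholder (1,)-shaped meta tensor is returned; no data is copied.
value = torch.empty(1, 4, 8, 64, dtype=torch.int8, device="meta")
cache = torch.empty(1, 128, 8, 64, dtype=torch.int8, device="meta")

out = torch.ops.llama.update_quantized_cache(value, cache, 0)
assert out.shape == (1,) and out.device.type == "meta"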

extension/llm/custom_ops/targets.bzl

Lines changed: 2 additions & 0 deletions
@@ -13,11 +13,13 @@ def define_common_targets():
            "op_fallback.cpp",
            "op_fast_hadamard_transform.cpp",
            "op_sdpa.cpp",
+           "op_update_quantized_cache.cpp",
        ],
        exported_headers = [
            "op_fallback.h",
            "op_fast_hadamard_transform.h",
            "op_sdpa.h",
+           "op_update_quantized_cache.h",
        ],
        exported_deps = [
            "//executorch/runtime/kernel:kernel_includes",
