Skip to content

Commit 5a18cc6

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Avoid converting k and v to q dtype (#2201)
Summary: Pull Request resolved: #2201 We don't want to hard code k and v into q's dtype, instead we should make sure they are always the same before feeding into sdpa. The problem was due to dtype mismatch between the kv cache used for tracing and the actual weights. This happens in the fp16 flow. After we convert the whole model to fp16, we still use fp32 kv cache tensors for tracing and that's causes the dtype mismatch issue. This diff changes the logic to be using the same dtype as the weights, for kv cache during tracing Reviewed By: mikekgfb Differential Revision: D54426672 fbshipit-source-id: d34009fedc59ebf5ba7ee77e26341a1f99340df6
1 parent 03c056b commit 5a18cc6

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

examples/models/llama2/builder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,15 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LlamaEdgeManager":
176176
logging.info(f"model.to {torch_dtype}")
177177
self.model = self.model.to(dtype=torch_dtype)
178178
self.dtype = dtype_override
179+
180+
# convert kv cache to dtype as well. This should be removed after mutable buffer is supported.
181+
# assuming the kv cache are the last 2 tensors in the example inputs
182+
if self.use_kv_cache:
183+
dtype = torch.float16 if self.dtype == DType.fp16 else torch.float32
184+
example_inputs = list(self.example_inputs[:-2]) + [
185+
cache.to(dtype) for cache in self.example_inputs[-2:]
186+
]
187+
self.example_inputs = tuple(example_inputs)
179188
return self
180189

181190
def source_transform(

examples/models/llama2/model.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,6 @@ def forward(
296296
# tensor will be 2-dimensional, regarldess of the values of l & s
297297
mask = torch.squeeze(mask, [0, 1])
298298

299-
# FIXME: This should be so automatically! MKG
300-
keys = keys.to(dtype=xq.dtype)
301-
values = values.to(dtype=xq.dtype)
302-
303299
output = F.scaled_dot_product_attention(
304300
xq, keys, values, attn_mask=mask, dropout_p=0.0
305301
)
@@ -672,8 +668,8 @@ def get_example_inputs(self):
672668

673669
def get_example_inputs_kvcache(self):
674670
cache_sizes = self.model_.get_cache_sizes()
675-
cache_k = torch.zeros(cache_sizes)
676-
cache_v = torch.zeros(cache_sizes)
671+
cache_k = torch.zeros(cache_sizes, dtype=self.dtype)
672+
cache_v = torch.zeros(cache_sizes, dtype=self.dtype)
677673
return (
678674
torch.tensor(
679675
[[1]], dtype=torch.long

0 commit comments

Comments
 (0)