
Commit e1c0815

test sdpa with fp16 (#553)
* test sdpa with fp16
* kv cache fp32
* typo
1 parent 33dc210 commit e1c0815
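
The change keeps the KV cache in fp32 while the rest of the model may run in fp16: q/k/v are cast up to float before the custom sdpa op, and the attention output is cast back to the query dtype afterwards. A minimal, self-contained sketch of that casting pattern, using torch.nn.functional.scaled_dot_product_attention as a stand-in for torch.ops.llama.sdpa_with_kv_cache (the stand-in op and the shapes are illustrative, not taken from the file):

import torch
import torch.nn.functional as F

# Illustrative shapes, not taken from export_et_util.py.
bsz, n_heads, seqlen, head_dim = 1, 8, 16, 64

# Activations in fp16, as in the "test sdpa with fp16" scenario.
q = torch.randn(bsz, n_heads, seqlen, head_dim, dtype=torch.float16)
k = torch.randn(bsz, n_heads, seqlen, head_dim, dtype=torch.float16)
v = torch.randn(bsz, n_heads, seqlen, head_dim, dtype=torch.float16)

# KV cache buffers stay fp32 ("kv cache fp32"), mirroring the
# `dtype = torch.float` override forced in CustomKVCache.__init__.
k_cache = torch.zeros(bsz, n_heads, seqlen, head_dim, dtype=torch.float)
v_cache = torch.zeros(bsz, n_heads, seqlen, head_dim, dtype=torch.float)
k_cache.copy_(k.float())
v_cache.copy_(v.float())

# Cast q up before attention (the patch passes q.float(), k.float(), v.float()
# into the custom op), then cast the result back to the activation dtype,
# mirroring `output.view(...).to(dtype=q.dtype)`.
output = F.scaled_dot_product_attention(q.float(), k_cache, v_cache)
output = output.to(dtype=q.dtype)
assert output.dtype == torch.float16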

File tree

1 file changed: +8 -6 lines changed


export_et_util.py

Lines changed: 8 additions & 6 deletions
@@ -9,6 +9,8 @@ class CustomKVCache(nn.Module):
     def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
         super().__init__()
 
+        dtype = torch.float
+
         # This is flipped around from what is in build.model's KVCache
         cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
         self.register_buffer(
@@ -21,8 +23,8 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
     def update(self, input_pos, k_val, v_val):
         k_out = self.k_cache
         v_out = self.v_cache
-        k_out[:, :, input_pos] = k_val
-        v_out[:, :, input_pos] = v_val
+        k_out[:, :, input_pos] = k_val.float()
+        v_out[:, :, input_pos] = v_val.float()
 
         return k_out, v_out
 
@@ -67,15 +69,15 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
         # KV cache should always be enabled
         assert self.kv_cache is not None
         output = torch.ops.llama.sdpa_with_kv_cache(
-            q,
-            k,
-            v,
+            q.float(),
+            k.float(),
+            v.float(),
             self.kv_cache.k_cache,
             self.kv_cache.v_cache,
             input_pos[-1].item(),
             seqlen,
         )
-        output = output.view(bsz, seqlen, self.dim)
+        output = output.view(bsz, seqlen, self.dim).to(dtype=q.dtype)
         return self.wo(output)
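
On the cache side, update() casts fp16 k/v values up before writing them into the fp32 buffers. A reduced sketch of that behavior, using a hypothetical stand-in class rather than the full CustomKVCache (the buffer registration, cache shape, and indexing here are assumptions chosen for the example):

import torch
import torch.nn as nn

class FP32Cache(nn.Module):
    # Hypothetical stand-in for CustomKVCache, reduced to the dtype handling:
    # whatever dtype the caller passes, the buffers are created as fp32.
    def __init__(self, cache_shape, dtype):
        super().__init__()
        dtype = torch.float  # override, as the patch does in __init__
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # fp16 inputs are cast up so the writes match the fp32 buffers,
        # mirroring `k_val.float()` / `v_val.float()` in the patch.
        self.k_cache[:, input_pos] = k_val.float()
        self.v_cache[:, input_pos] = v_val.float()
        return self.k_cache, self.v_cache

# The cache stays fp32 even when built for, and updated with, fp16 tensors.
cache = FP32Cache((1, 16, 8, 64), dtype=torch.float16)
k = torch.randn(1, 1, 8, 64, dtype=torch.float16)
v = torch.randn(1, 1, 8, 64, dtype=torch.float16)
k_out, v_out = cache.update(torch.tensor([0]), k, v)
assert k_out.dtype == torch.float32 and v_out.dtype == torch.float32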
