@@ -9,6 +9,8 @@ class CustomKVCache(nn.Module):
     def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
         super().__init__()
 
+        dtype = torch.float
+
         # This is flipped around from what is in build.model's KVCache
         cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
         self.register_buffer(
@@ -21,8 +23,8 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
     def update(self, input_pos, k_val, v_val):
         k_out = self.k_cache
         v_out = self.v_cache
-        k_out[:, :, input_pos] = k_val
-        v_out[:, :, input_pos] = v_val
+        k_out[:, :, input_pos] = k_val.float()
+        v_out[:, :, input_pos] = v_val.float()
 
         return k_out, v_out
 
@@ -67,15 +69,15 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
         # KV cache should always be enabled
         assert self.kv_cache is not None
         output = torch.ops.llama.sdpa_with_kv_cache(
-            q,
-            k,
-            v,
+            q.float(),
+            k.float(),
+            v.float(),
             self.kv_cache.k_cache,
             self.kv_cache.v_cache,
             input_pos[-1].item(),
             seqlen,
         )
-        output = output.view(bsz, seqlen, self.dim)
+        output = output.view(bsz, seqlen, self.dim).to(dtype=q.dtype)
         return self.wo(output)
 
 
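The pattern across these hunks is: keep the KV cache buffers in float32, cast the half-precision q/k/v up to float32 before the custom SDPA op, and cast the attention output back to the activation dtype. Below is a minimal sketch of that idea, assuming (batch, heads, seq_len, head_dim) tensors and using plain `F.scaled_dot_product_attention` as a stand-in for `torch.ops.llama.sdpa_with_kv_cache`; it illustrates the dtype round-trip, not the repo's actual implementation.

```python
import torch
import torch.nn.functional as F

def sdpa_with_fp32_cache(q, k_cache, v_cache):
    # Run attention entirely in float32 (the cache is already fp32),
    # then cast the result back to the original activation dtype.
    out = F.scaled_dot_product_attention(q.float(), k_cache, v_cache)
    return out.to(dtype=q.dtype)

# Half-precision activations, float32 cache buffers.
q = torch.randn(1, 8, 16, 64, dtype=torch.float16)
k_cache = torch.randn(1, 8, 16, 64, dtype=torch.float)
v_cache = torch.randn(1, 8, 16, 64, dtype=torch.float)
print(sdpa_with_fp32_cache(q, k_cache, v_cache).dtype)  # torch.float16
```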