@@ -37,8 +37,6 @@ def __init__(
        n_heads,
        head_dim,
        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
-        tranposed=False,
-        enable_dynamic_shape=False,
    ):
        super().__init__()
        if cache_type not in (
@@ -52,14 +50,8 @@ def __init__(
        # For now supporting int8 only
        self.quantized_cache_dtype = torch.int8
        self.cache_fp_type = torch.float32
-        self.is_transposed = tranposed
-        self.enable_dynamic_shape = enable_dynamic_shape
-        if self.is_transposed:
-            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
-            scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
-        else:
-            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
-            scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
+        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
        )
@@ -98,71 +90,37 @@ def _quantize(self, value):
        return quantized_value, scales, zero_points

    def update(self, input_pos, k_val, v_val):
+        """
+        k_val, v_val: [B, H, S, D]
+        return: [B, H, S, D]
+        However, the storage layout is [B, S, H, D], so we incur a transpose in and a transpose out.
+        These transposes should be removed by a subsequent post-export graph pass.
+        """
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
        # quantize current k_val and store it in the cache
        quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)

        quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

-        if self.is_transposed:
-            # We cannot use update_cache op at the moment
-            # if the cache is transposed
-            # Also note that we shold not need separate paths
-            # for dynamic shape vs !
-            # Only reason it is done this way is to accommodate
-            # for lowering pains of backends that work better
-            # with index_put op.
-            if self.enable_dynamic_shape:
-                start_pos = input_pos[0].item()
-                torch._check_is_size(start_pos)
-                dim_to_slice = 2 if self.is_transposed else 1
-                torch._check(start_pos < self.k_cache.size(dim_to_slice))
-                seq_length = k_val.size(dim_to_slice)
-                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_k_scales = self.k_cache_scales.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_k_zp = self.k_cache_zero_points.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_k.copy_(quantized_k_val)
-                narrowed_k_scales.copy_(k_scales)
-                narrowed_k_zp.copy_(k_zero_points)
-                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_v_scales = self.v_cache_scales.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_v_zp = self.v_cache_zero_points.narrow(
-                    dim_to_slice, start_pos, seq_length
-                )
-                narrowed_v.copy_(quantized_v_val)
-                narrowed_v_scales.copy_(v_scales)
-                narrowed_v_zp.copy_(v_zero_points)
-            else:
-                self.k_cache[:, :, input_pos] = quantized_k_val
-                self.k_cache_scales[:, :, input_pos] = k_scales
-                self.k_cache_zero_points[:, :, input_pos] = k_zero_points
-                self.v_cache[:, :, input_pos] = quantized_v_val
-                self.v_cache_scales[:, :, input_pos] = v_scales
-                self.v_cache_zero_points[:, :, input_pos] = v_zero_points
-        else:
-            # Right now using custom ops on this path.
-            # In future we can update custom op to handle transposed cache
-            # as well.
-            # Note that we may have to revert this change if other ET
-            # backends such as QNN want to use quantized cache, with dynamic shape,
-            # instead of quantizing on their own.
-            # But until this opting for code simplicity
-            start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
-            _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                k_zero_points, self.k_cache_zero_points, start_pos
-            )
-            _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
-            _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
-            _ = torch.ops.llama.update_cache(
-                v_zero_points, self.v_cache_zero_points, start_pos
-            )
+        # Right now we use custom ops on this path.
+        # In the future we can update the custom op to handle a transposed cache
+        # as well.
+        # Note that we may have to revert this change if other ET
+        # backends such as QNN want to use the quantized cache, with dynamic shape,
+        # instead of quantizing on their own.
+        # But until then, we opt for code simplicity.
+        start_pos = input_pos[0].item()
+        _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
+        _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
+        _ = torch.ops.llama.update_cache(
+            k_zero_points, self.k_cache_zero_points, start_pos
+        )
+        _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
+        _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
+        _ = torch.ops.llama.update_cache(
+            v_zero_points, self.v_cache_zero_points, start_pos
+        )

        k_out = torch.ops.quantized_decomposed.dequantize_per_token(
            self.k_cache,
@@ -183,42 +141,24 @@ def update(self, input_pos, k_val, v_val):
            self.cache_fp_type,
        )

-        if self.is_transposed:
-            if self.enable_dynamic_shape:
-                start_pos = input_pos[0].item()
-                torch._check_is_size(start_pos)
-                dim_to_slice = 2 if self.is_transposed else 1
-                torch._check(start_pos < self.k_cache.size(dim_to_slice))
-                seq_length = k_val.size(dim_to_slice)
-                narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_k.copy_(k_val)
-                narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
-                narrowed_v.copy_(v_val)
-            else:
-                k_out[:, :, input_pos] = k_val
-                v_out[:, :, input_pos] = v_val
-        else:
-            start_pos = input_pos[0].item()
-            _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
-            _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
+        start_pos = input_pos[0].item()
+        _ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
+        _ = torch.ops.llama.update_cache(v_val, v_out, start_pos)

-        return k_out, v_out
+        return k_out.transpose(1, 2), v_out.transpose(1, 2)

    @classmethod
    def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
-        cache_shape = kv_cache.k_cache.shape
-        if kv_cache.is_transposed:
-            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
-        else:
-            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+        max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape
+        if isinstance(kv_cache, CustomKVCache):
+            # If replacing custom kv cache, then the shape is [B, S, H, D]
+            max_batch_size, max_seq_length, n_heads, head_dim = kv_cache.k_cache.shape
        return cls(
            max_batch_size,
            max_seq_length,
            n_heads,
            head_dim,
            cache_type,
-            kv_cache.is_transposed,
-            kv_cache.enable_dynamic_shape,
        )

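For reference, the layout convention the new `update()` implements can be illustrated with a small standalone sketch. This is not the PR's code: it uses a plain slice copy where the real implementation calls the custom `torch.ops.llama.update_cache` op, it omits quantization entirely, and the sizes (`B`, `H`, `S_MAX`, `D`) are made-up values for illustration.

```python
import torch

# Storage layout is [B, S, H, D]; callers pass and receive [B, H, S, D].
B, H, S_MAX, D = 1, 4, 16, 8
k_cache = torch.zeros(B, S_MAX, H, D)


def update(k_cache, input_pos, k_val):
    # k_val arrives as [B, H, S, D], matching the attention module's layout.
    k_val = k_val.transpose(1, 2)  # -> [B, S, H, D] to match storage
    start_pos = int(input_pos[0].item())
    seq_len = k_val.size(1)
    # Stand-in for torch.ops.llama.update_cache: write the new tokens in place.
    k_cache[:, start_pos : start_pos + seq_len] = k_val
    # Hand the full cache back as [B, H, S_MAX, D], as callers expect.
    return k_cache.transpose(1, 2)


k_new = torch.randn(B, H, 3, D)  # three new tokens
k_out = update(k_cache, torch.tensor([0]), k_new)
assert k_out.shape == (B, H, S_MAX, D)
```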
@@ -254,7 +194,7 @@ def replace_kv_cache_with_quantized_kv_cache(module):
        "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
    )
    for name, child in module.named_children():
-        if isinstance(child, KVCache):
+        if isinstance(child, KVCache) or isinstance(child, CustomKVCache):
            setattr(
                module,
                name,
@@ -291,11 +231,13 @@ def __init__(
    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # input_pos: [S], k_val: [B, S, H, D]
+        # input_pos: [S], k_val: [B, H, S, D]
+        k_val = k_val.transpose(1, 2)
+        v_val = v_val.transpose(1, 2)
        start_pos = input_pos[0].item()
        _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
        _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
-        return self.k_cache, self.v_cache
+        return self.k_cache.transpose(1, 2), self.v_cache.transpose(1, 2)


def replace_kv_cache_with_custom_kv_cache(module):
@@ -313,10 +255,7 @@ def replace_kv_cache_with_custom_kv_cache(module):
        if isinstance(child, KVCache):
            cache_shape = child.k_cache.shape
            cache_dtype = child.k_cache.dtype
-            assert (
-                child.is_transposed is False
-            ), "CustomKVCache does not support transposed cache"
-            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
+            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
            setattr(
                module,
                name,
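The truncated final hunks follow the same replacement pattern used throughout this file: walk `module.named_children()` and swap each matching cache child in place via `setattr`. Below is a minimal, self-contained sketch of that pattern with hypothetical `FloatCache`/`QuantCache` stand-ins; it is not the ExecuTorch `replace_kv_cache_with_quantized_kv_cache` implementation, whose `from_float` also takes a `cache_type` argument.

```python
import torch
import torch.nn as nn


class FloatCache(nn.Module):  # stand-in for KVCache / CustomKVCache
    def __init__(self):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(1, 16, 4, 8))


class QuantCache(nn.Module):  # stand-in for QuantizedKVCache
    @classmethod
    def from_float(cls, fp_cache):
        # A real implementation would read shapes/dtypes off fp_cache here.
        return cls()


def replace_caches(module):
    # Recursively walk the module tree and swap caches in place.
    for name, child in module.named_children():
        if isinstance(child, FloatCache):
            setattr(module, name, QuantCache.from_float(child))
        else:
            replace_caches(child)
    return module


model = nn.ModuleDict({"layer0": nn.ModuleDict({"kv": FloatCache()})})
replace_caches(model)
assert isinstance(model["layer0"]["kv"], QuantCache)
```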