@@ -193,6 +193,95 @@ def update(
         return k_out, v_out


+class SDPA(nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        use_sdpa_with_kv_cache_op: bool,
+        dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
+        self.dim = dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        if not self.use_sdpa_with_kv_cache_op:
+            return self._forward_default(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+        else:
+            return self._forward_custom(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+
+    def _forward_custom(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        from .custom_ops import sdpa_with_kv_cache  # noqa
+
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+    def _forward_default(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        mask = self.mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
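For reference, a minimal sketch of exercising the extracted SDPA module's default (non-custom-op) path on its own. It assumes the SDPA class above is importable from the module this diff touches (import path omitted); StubKVCache is a hypothetical stand-in that follows the same update() contract as the repo's KVCache in its transposed layout, and all shapes are arbitrary.

```python
import torch
import torch.nn as nn

# from <module containing this diff> import SDPA  # assumed to be in scope


class StubKVCache(nn.Module):
    """Hypothetical stand-in: same update() contract as the repo's KVCache,
    using the transposed (bsz, n_kv_heads, max_seq_len, head_dim) layout."""

    def __init__(self, bsz, n_kv_heads, max_seq_len, head_dim):
        super().__init__()
        self.k_cache = torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim)
        self.v_cache = torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim)

    def update(self, input_pos, k, v):
        # k, v arrive already transposed: (bsz, n_kv_heads, seqlen, head_dim)
        self.k_cache[:, :, input_pos] = k
        self.v_cache[:, :, input_pos] = v
        return self.k_cache, self.v_cache


bsz, seqlen, max_seq_len = 1, 4, 32
n_heads, n_kv_heads, head_dim = 8, 2, 16
dim = n_heads * head_dim

cache = StubKVCache(bsz, n_kv_heads, max_seq_len, head_dim)
mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

sdpa = SDPA(cache, mask, use_sdpa_with_kv_cache_op=False, dim=dim, n_rep=n_heads // n_kv_heads)

input_pos = torch.arange(seqlen)
q = torch.randn(bsz, seqlen, n_heads, head_dim)    # projections arrive untransposed
k = torch.randn(bsz, seqlen, n_kv_heads, head_dim)
v = torch.randn(bsz, seqlen, n_kv_heads, head_dim)

out = sdpa(input_pos, q, k, v, bsz, seqlen)
assert out.shape == (bsz, seqlen, dim)
```

This mirrors what Attention.forward does on the KV-cache path after the refactor, minus the output projection self.wo.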
@@ -213,7 +302,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

-        self.use_sdpa_with_kv_cache_op = args.use_sdpa_with_kv_cache_op
         self.layer_id = layer_id

         causal_mask = torch.tril(
@@ -234,6 +322,13 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 self.head_dim,
                 not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
             )
+            self.SDPA = SDPA(
+                self.kv_cache,
+                self.mask,
+                args.use_sdpa_with_kv_cache_op,
+                self.dim,
+                self.n_rep,
+            )

     def forward(
         self,
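The `not args.use_sdpa_with_kv_cache_op` argument passed to KVCache above is the "transpose the cache" switch that the inline comment refers to. A purely illustrative sketch of the two layouts it implies; the shapes are inferred from that comment and from _forward_default, not read out of KVCache itself:

```python
import torch

bsz, n_kv_heads, max_seq_len, head_dim = 1, 2, 32, 16

# Default eager path (transpose the cache): heads-first layout, matching the
# transposed k/v that SDPA._forward_default produces before KVCache.update().
k_cache_transposed = torch.zeros(bsz, n_kv_heads, max_seq_len, head_dim)

# Custom-op path (don't transpose the cache): sdpa_with_kv_cache expects
# untransposed q/k/v, so the sequence dimension stays ahead of the head dimension.
k_cache_untransposed = torch.zeros(bsz, max_seq_len, n_kv_heads, head_dim)
```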
@@ -256,41 +351,8 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
-
-            if not self.use_sdpa_with_kv_cache_op:
-
-                q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-                k = k.transpose(1, 2)
-                v = v.transpose(1, 2)
-
-                k, v = self.kv_cache.update(input_pos, k, v)
-                mask = self.mask[None, None, input_pos]
-
-                k = k.repeat_interleave(self.n_rep, dim=1)
-                v = v.repeat_interleave(self.n_rep, dim=1)
-                y = F.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=0.0
-                )
-
-                y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-                y = self.wo(y)
-                return y
-            else:
-                from .custom_ops import sdpa_with_kv_cache  # noqa
-
-                output = torch.ops.llama.sdpa_with_kv_cache(
-                    q,
-                    k,
-                    v,
-                    self.kv_cache.k_cache,
-                    self.kv_cache.v_cache,
-                    input_pos[-1].item(),
-                    seqlen,
-                )
-                output = output.view(bsz, seqlen, -1)
-                output = self.wo(output)
-                return output
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            return self.wo(output)

         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
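The default path removed here (now living in SDPA._forward_default) expands the grouped K/V heads with repeat_interleave before calling F.scaled_dot_product_attention. A small standalone illustration of that step; the head counts are arbitrary:

```python
import torch

n_heads, n_kv_heads = 8, 2
n_rep = n_heads // n_kv_heads  # query heads sharing each K/V head

# (bsz, n_kv_heads, seqlen, head_dim) -> (bsz, n_heads, seqlen, head_dim)
k = torch.randn(1, n_kv_heads, 32, 16)
k_expanded = k.repeat_interleave(n_rep, dim=1)
assert k_expanded.shape == (1, n_heads, 32, 16)
```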