@@ -209,6 +209,95 @@ def update(
         return k_out, v_out
 
 
+class SDPA(nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        use_sdpa_with_kv_cache_op: bool,
+        dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
+        self.dim = dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        if not self.use_sdpa_with_kv_cache_op:
+            return self._forward_default(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+        else:
+            return self._forward_custom(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+
+    def _forward_custom(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        from .custom_ops import sdpa_with_kv_cache  # noqa
+
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+    def _forward_default(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        mask = self.mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
@@ -229,7 +318,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
 
-        self.use_sdpa_with_kv_cache_op = args.use_sdpa_with_kv_cache_op
         self.layer_id = layer_id
 
         causal_mask = torch.tril(
@@ -250,6 +338,13 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 self.head_dim,
                 not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
             )
+            self.SDPA = SDPA(
+                self.kv_cache,
+                self.mask,
+                args.use_sdpa_with_kv_cache_op,
+                self.dim,
+                self.n_rep,
+            )
 
     def forward(
         self,
@@ -272,41 +367,8 @@ def forward(
 
         if self.use_kv_cache:
            assert input_pos is not None
-
-            if not self.use_sdpa_with_kv_cache_op:
-
-                q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-                k = k.transpose(1, 2)
-                v = v.transpose(1, 2)
-
-                k, v = self.kv_cache.update(input_pos, k, v)
-                mask = self.mask[None, None, input_pos]
-
-                k = k.repeat_interleave(self.n_rep, dim=1)
-                v = v.repeat_interleave(self.n_rep, dim=1)
-                y = F.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=0.0
-                )
-
-                y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-                y = self.wo(y)
-                return y
-            else:
-                from .custom_ops import sdpa_with_kv_cache  # noqa
-
-                output = torch.ops.llama.sdpa_with_kv_cache(
-                    q,
-                    k,
-                    v,
-                    self.kv_cache.k_cache,
-                    self.kv_cache.v_cache,
-                    input_pos[-1].item(),
-                    seqlen,
-                )
-                output = output.view(bsz, seqlen, -1)
-                output = self.wo(output)
-                return output
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            return self.wo(output)
 
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
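For readers who want to try the non-custom-op attention path that the new SDPA._forward_default encapsulates, here is a minimal, self-contained sketch. It is not the PR's code: ToyKVCache and all shapes below are illustrative assumptions standing in for the model's KVCache, causal mask, and grouped-query-attention parameters.

# Sketch of the default SDPA path: update a toy KV cache, expand grouped KV heads
# with repeat_interleave, and call F.scaled_dot_product_attention with a causal mask.
import torch
import torch.nn.functional as F


class ToyKVCache:
    """Stand-in for the model's KVCache: fixed-size buffers indexed by input_pos."""

    def __init__(self, max_seq_len: int, n_kv_heads: int, head_dim: int):
        self.k_cache = torch.zeros(1, n_kv_heads, max_seq_len, head_dim)
        self.v_cache = torch.zeros(1, n_kv_heads, max_seq_len, head_dim)

    def update(self, input_pos, k, v):
        # k, v: (bs, n_kv_heads, seqlen, head_dim); write them at the given positions.
        self.k_cache[:, :, input_pos] = k
        self.v_cache[:, :, input_pos] = v
        return self.k_cache, self.v_cache


bsz, seqlen, max_seq_len = 1, 4, 16          # illustrative sizes
n_heads, n_kv_heads, head_dim = 8, 2, 32
n_rep = n_heads // n_kv_heads                # query heads per KV head
dim = n_heads * head_dim

cache = ToyKVCache(max_seq_len, n_kv_heads, head_dim)
mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.arange(seqlen)

q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_kv_heads, seqlen, head_dim)
v = torch.randn(bsz, n_kv_heads, seqlen, head_dim)

k, v = cache.update(input_pos, k, v)          # (bs, n_kv_heads, max_seq_len, head_dim)
attn_mask = mask[None, None, input_pos]       # (1, 1, seqlen, max_seq_len)
k = k.repeat_interleave(n_rep, dim=1)         # expand KV heads to match q
v = v.repeat_interleave(n_rep, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
print(y.transpose(1, 2).contiguous().view(bsz, seqlen, dim).shape)  # (1, 4, 256)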