@@ -352,26 +352,18 @@ def forward(
352
352
# Unsqueeze + expand + flatten bring num_kv_heads to num_heads for k and v
353
353
# to match q.
354
354
355
- # k: [bsz, seq_len, n_kv, 1, h_d]
356
- # v: [bsz, seq_len, n_kv, 1, h_d]
357
- k = k .view (bsz , - 1 , self .num_kv_heads , 1 , self .head_dim )
358
- v = v .view (bsz , - 1 , self .num_kv_heads , 1 , self .head_dim )
359
-
360
- # Expand the key and value tensors to have the same shape
361
- # as the query tensor by copying values across the relevant dim
362
- if self .num_heads != self .num_kv_heads :
363
- k = k .expand (bsz , - 1 , self .num_kv_heads , self .q_per_kv , self .head_dim )
364
- v = v .expand (bsz , - 1 , self .num_kv_heads , self .q_per_kv , self .head_dim )
365
-
366
- # [bsz, s, n_h, h_d]
367
- k = k .reshape (bsz , - 1 , self .num_heads , self .head_dim )
368
- v = v .reshape (bsz , - 1 , self .num_heads , self .head_dim )
369
-
370
355
# [bsz, n_h, s, h_d]
371
356
q = q .transpose (1 , 2 )
372
357
k = k .transpose (1 , 2 )
373
358
v = v .transpose (1 , 2 )
374
359
360
+ # Expand the key and value tensors to have the same shape
361
+ # as the query tensor by copying values across the relevant dim
362
+ if self .num_heads != self .num_kv_heads :
363
+ expand_shape = (- 1 , - 1 , self .q_per_kv , - 1 , - 1 )
364
+ k = k .unsqueeze (2 ).expand (expand_shape ).flatten (1 , 2 )
365
+ v = v .unsqueeze (2 ).expand (expand_shape ).flatten (1 , 2 )
366
+
375
367
output = self ._attention_fn (
376
368
q ,
377
369
k ,
0 commit comments