Skip to content

Commit ee7ae68

Browse files
authored
Delete unused code
Differential Revision: D61497050 Pull Request resolved: #4780
1 parent 482f60f commit ee7ae68

File tree

1 file changed

+0
-12
lines changed

1 file changed

+0
-12
lines changed

examples/models/llama2/llama_transformer.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -131,18 +131,6 @@ def __post_init__(self):
131131
self.hidden_dim = find_multiple(hidden_dim, multiple_of)
132132

133133

134-
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
135-
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
136-
bs, slen, n_kv_heads, head_dim = x.shape
137-
if n_rep == 1:
138-
return x
139-
return (
140-
x[:, :, :, None, :]
141-
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
142-
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
143-
)
144-
145-
146134
class KVCache(nn.Module):
147135
def __init__(
148136
self,

0 commit comments

Comments
 (0)