Commit 318f65e

move rope related logic together
Differential Revision: D65173598
Pull Request resolved: #6560
1 parent: 85698df

File tree

2 files changed: +101 / -66 lines


examples/models/llama/llama_transformer.py

Lines changed: 85 additions & 54 deletions
@@ -147,6 +147,81 @@ def __post_init__(self):
         self.head_dim = self.dim // self.n_heads
 
 
+class Rope(torch.nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        if self.params.use_hf_rope:
+            self.precompute_freqs_cis = hf_precompute_freqs_cis
+        else:
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+            )
+        freqs_cos, freqs_sin = self.precompute_freqs_cis(
+            self.params.head_dim,
+            (
+                self.params.max_seq_len  # Normal llama2.
+                if self.params.ffn_dim_multiplier is None
+                else self.params.max_seq_len * 2  # Sharded checkpoint.
+            ),
+            self.params.rope_freq_base,
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        else:
+            self.apply_rotary_emb = RotaryEmbedding()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+    ):
+        return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+
+    def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
+        """
+        Get the precomputed frequencies for the given input position and sequence length.
+
+        Args:
+            input_pos (torch.Tensor): The input position tensor.
+            seq_len (int): The sequence length.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The precomputed frequencies for the given input position and sequence length.
+        """
+        if self.params.use_kv_cache:
+            assert (
+                input_pos is not None
+            ), "input_pos must be provided when use_kv_cache is True"
+
+            if self.params.enable_dynamic_shape:
+                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+            else:
+                # When not using dynamic shape, use of the .item results in
+                # symints, due to querying the data from tensor.
+                # this path avoids that for mps backend, although probably mps backend
+                # can support dynamic shape?
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
+        else:
+            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
+            freqs_cos = self.freqs_cos[:seq_len]
+            freqs_sin = self.freqs_sin[:seq_len]
+        return freqs_cos, freqs_sin
+
+
 class KVCache(nn.Module):
     def __init__(
         self,
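For orientation, here is a minimal usage sketch of the new Rope module introduced above. It is illustrative only: the import path and the ModelArgs constructor arguments are assumptions, not part of this diff.

# Sketch only. Assumed import path and ModelArgs defaults; adjust to your checkout.
import torch
from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope

params = ModelArgs(use_kv_cache=False)  # assumed arguments; dataclass defaults fill in the rest
rope = Rope(params)

# No KV cache: input_pos is None and the first seq_len rows of the tables are used.
seq_len = 8
freqs_cos, freqs_sin = rope.get_freqs(None, seq_len)

# Apply rotary embeddings to query/key tensors shaped as in Attention.forward:
# (batch, seq_len, n_heads, head_dim).
bsz = 1
q = torch.randn(bsz, seq_len, params.n_heads, params.head_dim)
k = torch.randn(bsz, seq_len, params.n_heads, params.head_dim)
q, k = rope.forward(q, k, freqs_cos, freqs_sin)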
@@ -266,7 +341,7 @@ def forward(
 
 
 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -287,6 +362,8 @@ def __init__(self, args: ModelArgs, layer_id: int):
 
         self.layer_id = layer_id
 
+        self.rope = rope
+
         causal_mask = torch.tril(
             torch.ones(
                 self.max_seq_len,
@@ -303,7 +380,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
         self.SDPA = SDPA(
@@ -314,10 +391,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
             max_seq_len=self.max_seq_len,
             enable_dynamic_shape=args.enable_dynamic_shape,
         )
-        if args.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()
 
     def forward(
         self,
@@ -336,7 +409,7 @@ def forward(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
 
         # RoPE relative positional embeddings
-        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
 
         if self.use_kv_cache:
             assert input_pos is not None
@@ -424,13 +497,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        self.attention = Attention(args, layer_id)
+        self.attention = Attention(args, layer_id, rope)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -459,33 +532,17 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers
 
         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        if params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
-        else:
-            self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=params.use_scaled_rope
-            )
-        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.head_dim,
-            (
-                params.max_seq_len  # Normal llama2.
-                if params.ffn_dim_multiplier is None
-                else params.max_seq_len * 2  # Sharded checkpoint.
-            ),
-            params.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
 
     def forward(
         self,
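Worth noting about the constructor change above: Transformer now builds a single Rope and passes it to every TransformerBlock, so the precomputed freqs_cos/freqs_sin tables are owned by the Rope module and shared by all layers rather than living on the Transformer itself. A small illustrative check, with an assumed import path and constructor arguments:

# Sketch only; assumed import path and ModelArgs values.
from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer

model = Transformer(ModelArgs(n_layers=2, use_kv_cache=False))
# Every layer's Attention should hold a reference to the same Rope instance.
assert all(layer.attention.rope is model.rope for layer in model.layers)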
@@ -502,33 +559,7 @@ def forward(
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
         seqlen = h.shape[1]
-
-        if self.use_kv_cache:
-            assert (
-                input_pos is not None
-            ), "input_pos must be provided when use_kv_cache is True"
-
-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
-
-        else:
-            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
-            freqs_cos = self.freqs_cos[:seqlen]
-            freqs_sin = self.freqs_sin[:seqlen]
+        freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)
 
         for layer in self.layers:
            h = layer(
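With the inline slicing above replaced by Rope.get_freqs, the two call patterns look roughly like this. This is a hedged sketch; the import path and ModelArgs values are assumptions, not part of this diff.

# Sketch only; assumed import path and ModelArgs values.
import torch
from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope

# Decode step with a KV cache: input_pos marks the current cache position, seq_len is 1.
params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True, max_seq_len=128)
rope = Rope(params)
input_pos = torch.tensor([5])
freqs_cos, freqs_sin = rope.get_freqs(input_pos, 1)

# Without a KV cache, input_pos must be None and the first seq_len rows are taken.
rope_no_cache = Rope(ModelArgs(use_kv_cache=False))
freqs_cos, freqs_sin = rope_no_cache.get_freqs(None, 16)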

examples/models/llama/source_transformation/rope.py

Lines changed: 16 additions & 12 deletions
@@ -13,23 +13,27 @@ def materialze_broadcast_of_rope_freq_cis(
     module: torch.nn.Module,
 ):
     assert isinstance(module, Transformer)
-    assert module.freqs_cos.dim() == 2
-    dim0 = module.freqs_cos.size(0)
-    dim1 = module.freqs_cos.size(1)
+    assert module.rope.freqs_cos.dim() == 2
+    dim0 = module.rope.freqs_cos.size(0)
+    dim1 = module.rope.freqs_cos.size(1)
     module_attention = module.layers[0].attention
     assert (
         module_attention.n_local_kv_heads == module_attention.n_local_heads
     ), f"For rope freqs to be materialized for broadcast, q, k, v num heads must match. For q got {module_attention.n_kv_heads} for k got {module_attention.n_local_heads} and v got {module_attention.n_local_kv_heads}"
     num_heads = module_attention.n_local_heads
-    module.freqs_cos = module.freqs_cos.view(dim0, 1, dim1)
-    module.freqs_cos = module.freqs_cos.expand(dim0, num_heads, dim1).contiguous()
-    assert module.freqs_sin.dim() == 2
-    assert dim0 == module.freqs_sin.size(
+    module.rope.freqs_cos = module.rope.freqs_cos.view(dim0, 1, dim1)
+    module.rope.freqs_cos = module.rope.freqs_cos.expand(
+        dim0, num_heads, dim1
+    ).contiguous()
+    assert module.rope.freqs_sin.dim() == 2
+    assert dim0 == module.rope.freqs_sin.size(
         0
-    ), f"sin and cos freq table sizes must match. Mismatch found at dim 0: {dim0} vs {module.freqs_sin.size(0)}"
-    assert dim1 == module.freqs_sin.size(
+    ), f"sin and cos freq table sizes must match. Mismatch found at dim 0: {dim0} vs {module.rope.freqs_sin.size(0)}"
+    assert dim1 == module.rope.freqs_sin.size(
         1
-    ), f"sin and cos freq table sizes must match. Mismatch found at dim 1: {dim1} vs {module.freqs_sin.size(1)}"
-    module.freqs_sin = module.freqs_sin.view(dim0, 1, dim1)
-    module.freqs_sin = module.freqs_sin.expand(dim0, num_heads, dim1).contiguous()
+    ), f"sin and cos freq table sizes must match. Mismatch found at dim 1: {dim1} vs {module.rope.freqs_sin.size(1)}"
+    module.rope.freqs_sin = module.rope.freqs_sin.view(dim0, 1, dim1)
+    module.rope.freqs_sin = module.rope.freqs_sin.expand(
+        dim0, num_heads, dim1
+    ).contiguous()
     return module
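A hedged sketch of how this source transformation is applied after the refactor: it now reaches the rope tables through module.rope and expands them from roughly (seq_len, head_dim // 2) to (seq_len, n_heads, head_dim // 2) so the broadcast is materialized up front. The import paths and constructor arguments below are assumptions.

# Sketch only; assumed import paths and ModelArgs defaults.
from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer
from executorch.examples.models.llama.source_transformation.rope import (
    materialze_broadcast_of_rope_freq_cis,
)

model = Transformer(ModelArgs(use_kv_cache=False))
before = model.rope.freqs_cos.shape   # e.g. (max_seq_len, head_dim // 2) for the default rope
model = materialze_broadcast_of_rope_freq_cis(model)
after = model.rope.freqs_cos.shape    # e.g. (max_seq_len, n_heads, head_dim // 2)
assert after == (before[0], model.params.n_heads, before[1])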
