Commit f2ad9d0

move rope related logic together
Right now, RoPE-related code is scattered across a few different places in `llama_transformer`, which makes it hard to change anything RoPE-related. This PR moves all RoPE-related logic into its own module.

Differential Revision: [D65173598](https://our.internmc.facebook.com/intern/diff/D65173598/)

[ghstack-poisoned]
1 parent 2c32bf3 commit f2ad9d0
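
As a rough sketch of the wiring this change introduces (toy names, not the actual ExecuTorch classes): the frequency tables are precomputed once in a shared rope module, and each attention layer keeps a reference to that module instead of receiving `freqs_cos`/`freqs_sin` through every `forward()` call.

```python
import torch
import torch.nn as nn

# Toy sketch of the new wiring (hypothetical names, not the real ExecuTorch classes):
# precompute the RoPE tables once in a shared module and hand that module to every
# attention layer, instead of threading freqs_cos/freqs_sin through each forward().


class ToyRope(nn.Module):
    def __init__(self, head_dim: int, max_seq_len: int):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        freqs = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
        self.register_buffer("freqs_cos", freqs.cos(), persistent=False)
        self.register_buffer("freqs_sin", freqs.sin(), persistent=False)

    def forward(self, q, k, seq_len, input_pos=None):
        # Pick the rows for the current positions; the actual rotation is omitted here.
        cos = self.freqs_cos[:seq_len] if input_pos is None else self.freqs_cos[input_pos]
        sin = self.freqs_sin[:seq_len] if input_pos is None else self.freqs_sin[input_pos]
        return q, k  # a real implementation would rotate q and k using cos/sin


class ToyAttention(nn.Module):
    def __init__(self, rope: ToyRope):
        super().__init__()
        self.rope = rope  # shared instance, mirroring Attention(args, layer_id, rope)

    def forward(self, q, k, input_pos=None):
        return self.rope.forward(q, k, q.shape[1], input_pos)


rope = ToyRope(head_dim=64, max_seq_len=128)      # built once at the top level
layers = [ToyAttention(rope) for _ in range(4)]   # every layer shares the same tables
```

Sharing one instance keeps a single pair of frequency buffers and puts any future change to the rope math in one place.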

File tree: 1 file changed (+75, −67)

examples/models/llama/llama_transformer.py

Lines changed: 75 additions & 67 deletions
@@ -143,6 +143,69 @@ def __post_init__(self):
         self.hidden_dim = find_multiple(hidden_dim, multiple_of)


+class Rope(torch.nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        if self.params.use_hf_rope:
+            self.precompute_freqs_cis = hf_precompute_freqs_cis
+        else:
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+            )
+        freqs_cos, freqs_sin = self.precompute_freqs_cis(
+            self.params.dim // self.params.n_heads,
+            (
+                self.params.max_seq_len  # Normal llama2.
+                if self.params.ffn_dim_multiplier is None
+                else self.params.max_seq_len * 2  # Sharded checkpoint.
+            ),
+            self.params.rope_freq_base,
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        else:
+            self.apply_rotary_emb = RotaryEmbedding()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        seq_len: int,
+        input_pos: Optional[torch.Tensor] = None,
+    ):
+        if self.params.use_kv_cache:
+            assert (
+                input_pos is not None
+            ), "input_pos must be provided when use_kv_cache is True"
+
+            if self.params.enable_dynamic_shape:
+                # When the KV cache is used, seq_len is most likely 1. We want to slice from input_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+            else:
+                # When not using dynamic shape, using .item() results in
+                # symints, due to querying the data from the tensor.
+                # This path avoids that for the MPS backend, although the MPS
+                # backend can probably support dynamic shape.
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
+        else:
+            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
+            freqs_cos = self.freqs_cos[:seq_len]
+            freqs_sin = self.freqs_sin[:seq_len]
+        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        return q, k
+
+
 class KVCache(nn.Module):
     def __init__(
         self,
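
The `Rope.forward` added above has two ways of selecting rows from the precomputed tables when the KV cache is enabled: `narrow` on a scalar position (dynamic-shape path) and plain indexing with the `input_pos` tensor. A minimal standalone sketch, with toy sizes, showing that both select the same rows:

```python
import torch

# Toy stand-ins for the precomputed freqs_cos/freqs_sin buffers (hypothetical sizes).
max_seq_len, half_head_dim = 8, 4
freqs_cos = torch.randn(max_seq_len, half_head_dim)

# Decode step with a KV cache: input_pos holds the current position, seq_len is 1.
input_pos = torch.tensor([5])
seq_len = 1

# Dynamic-shape path: narrow along dim 0 starting at the scalar position.
start = input_pos[-1].item()
by_narrow = freqs_cos.narrow(0, start, seq_len)

# Non-dynamic-shape path: plain indexing with the position tensor.
by_index = freqs_cos[input_pos]

assert torch.equal(by_narrow, by_index)  # both pick row 5
```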
@@ -262,7 +325,7 @@ def forward(


 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -284,6 +347,8 @@ def __init__(self, args: ModelArgs, layer_id: int):

         self.layer_id = layer_id

+        self.rope = rope
+
         causal_mask = torch.tril(
             torch.ones(
                 self.max_seq_len,
@@ -300,7 +365,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
@@ -311,16 +376,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 max_seq_len=self.max_seq_len,
                 enable_dynamic_shape=args.enable_dynamic_shape,
             )
-        if args.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()

     def forward(
         self,
         x: torch.Tensor,
-        freqs_cos: torch.Tensor,
-        freqs_sin: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
     ):
         bsz, seqlen, _ = x.shape
@@ -333,7 +392,7 @@ def forward(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

         # RoPE relative positional embeddings
-        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        q, k = self.rope.forward(q, k, seqlen, input_pos)

         if self.use_kv_cache:
             assert input_pos is not None
@@ -421,24 +480,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.dim // args.n_heads
-        self.attention = Attention(args, layer_id)
+        self.attention = Attention(args, layer_id, rope)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos
-        )
+    def forward(self, x, input_pos=None):  # x: 1xN
+        h = self.attention.forward(self.attention_norm(x), input_pos)

         h = x + h
         if hasattr(self, "block_sparse_moe"):
@@ -456,33 +513,17 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers

         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        if params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
-        else:
-            self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=params.use_scaled_rope
-            )
-        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads,
-            (
-                params.max_seq_len  # Normal llama2.
-                if params.ffn_dim_multiplier is None
-                else params.max_seq_len * 2  # Sharded checkpoint.
-            ),
-            params.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

     def forward(
         self,
@@ -498,42 +539,9 @@ def forward(
         )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
-        seqlen = h.shape[1]
-
-        if self.use_kv_cache:
-            assert (
-                input_pos is not None
-            ), "input_pos must be provided when use_kv_cache is True"
-
-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
-
-        else:
-            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
-            freqs_cos = self.freqs_cos[:seqlen]
-            freqs_sin = self.freqs_sin[:seqlen]

         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h = layer(h, input_pos)

         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
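
With the frequency plumbing removed from `Transformer.forward`, a generation step reduces to `h = layer(h, input_pos)` per layer. Below is a hypothetical greedy decode loop, sketched under the assumption that the model was exported with `use_kv_cache=True` and returns logits for the last position only; `model` is only a stand-in for the Transformer above, and shows how the caller advances `input_pos`:

```python
import torch

def toy_decode(model, first_token: torch.Tensor, max_new_tokens: int) -> list:
    """Hypothetical greedy decode loop; `model` stands in for the Transformer above."""
    generated = [int(first_token)]
    cur = first_token.view(1, 1)                      # (bsz=1, seqlen=1)
    for pos in range(max_new_tokens):
        input_pos = torch.tensor([pos])               # position of the token being fed
        logits = model(cur, input_pos=input_pos)      # rope/KV cache slice at this position
        next_token = logits.reshape(1, -1).argmax(dim=-1)  # greedy pick
        generated.append(int(next_token))
        cur = next_token.view(1, 1)
    return generated
```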
