Skip to content

move rope related logic together #6560

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 85 additions & 54 deletions examples/models/llama/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,81 @@ def __post_init__(self):
self.head_dim = self.dim // self.n_heads


class Rope(torch.nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
if self.params.use_hf_rope:
self.precompute_freqs_cis = hf_precompute_freqs_cis
else:
self.precompute_freqs_cis = partial(
precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
)
freqs_cos, freqs_sin = self.precompute_freqs_cis(
self.params.head_dim,
(
self.params.max_seq_len # Normal llama2.
if self.params.ffn_dim_multiplier is None
else self.params.max_seq_len * 2 # Sharded checkpoint.
),
self.params.rope_freq_base,
)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
if self.params.use_hf_rope:
self.apply_rotary_emb = hf_apply_rotary_emb
else:
self.apply_rotary_emb = RotaryEmbedding()

def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
):
return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)

def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
"""
Get the precomputed frequencies for the given input position and sequence length.

Args:
input_pos (torch.Tensor): The input position tensor.
seq_len (int): The sequence length.

Returns:
Tuple[torch.Tensor, torch.Tensor]: The precomputed frequencies for the given input position and sequence length.
"""
if self.params.use_kv_cache:
assert (
input_pos is not None
), "input_pos must be provided when use_kv_cache is True"

if self.params.enable_dynamic_shape:
# when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
input_pos_item = input_pos[-1].item()
torch._check_is_size(input_pos_item)
torch._check(input_pos_item < self.params.max_seq_len)
# pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
# pyre-ignore: Incompatible parameter type [6]
freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
else:
# When not using dynamic shape, use of the .item results in
# symints, due to querying the data from tensor.
# this path avoids that for mps backend, although probably mps backend
# can support dynamic shape?
freqs_cos = self.freqs_cos[input_pos]
freqs_sin = self.freqs_sin[input_pos]

else:
assert input_pos is None, "input_pos is unused when use_kv_cache is False"
freqs_cos = self.freqs_cos[:seq_len]
freqs_sin = self.freqs_sin[:seq_len]
return freqs_cos, freqs_sin


class KVCache(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -266,7 +341,7 @@ def forward(


class Attention(nn.Module):
def __init__(self, args: ModelArgs, layer_id: int):
def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
super().__init__()
self.use_kv_cache = args.use_kv_cache
self.n_heads = args.n_heads
Expand All @@ -287,6 +362,8 @@ def __init__(self, args: ModelArgs, layer_id: int):

self.layer_id = layer_id

self.rope = rope

causal_mask = torch.tril(
torch.ones(
self.max_seq_len,
Expand All @@ -303,7 +380,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
args.max_seq_len,
self.n_kv_heads,
self.head_dim,
not args.use_sdpa_with_kv_cache_op, # if we are using the custom op dont transpose the cache. Expect untransposed q k v
not args.use_sdpa_with_kv_cache_op, # if we are using the custom op don't transpose the cache. Expect untransposed q k v
args.enable_dynamic_shape,
)
self.SDPA = SDPA(
Expand All @@ -314,10 +391,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
max_seq_len=self.max_seq_len,
enable_dynamic_shape=args.enable_dynamic_shape,
)
if args.use_hf_rope:
self.apply_rotary_emb = hf_apply_rotary_emb
else:
self.apply_rotary_emb = RotaryEmbedding()

def forward(
self,
Expand All @@ -336,7 +409,7 @@ def forward(
v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

# RoPE relative positional embeddings
q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

if self.use_kv_cache:
assert input_pos is not None
Expand Down Expand Up @@ -424,13 +497,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
super().__init__()
self.use_kv_cache = args.use_kv_cache
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.head_dim
self.attention = Attention(args, layer_id)
self.attention = Attention(args, layer_id, rope)
if args.moe:
self.block_sparse_moe = MOEFeedForward(args)
else:
Expand Down Expand Up @@ -459,33 +532,17 @@ def __init__(self, params: ModelArgs):
self.n_layers = params.n_layers

self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.rope = Rope(params)
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))
self.layers.append(TransformerBlock(layer_id, params, self.rope))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.use_kv_cache = params.use_kv_cache
self.generate_full_logits = params.generate_full_logits
self.max_seq_len = params.max_seq_len
self.input_prune_map = params.input_prune_map
self.output_prune_map = params.output_prune_map
if params.use_hf_rope:
self.precompute_freqs_cis = hf_precompute_freqs_cis
else:
self.precompute_freqs_cis = partial(
precompute_freqs_cis, use_scaled=params.use_scaled_rope
)
freqs_cos, freqs_sin = self.precompute_freqs_cis(
params.head_dim,
(
params.max_seq_len # Normal llama2.
if params.ffn_dim_multiplier is None
else params.max_seq_len * 2 # Sharded checkpoint.
),
params.rope_freq_base,
)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)

def forward(
self,
Expand All @@ -502,33 +559,7 @@ def forward(
if tokens is not None and h is None:
h = self.tok_embeddings(tokens)
seqlen = h.shape[1]

if self.use_kv_cache:
assert (
input_pos is not None
), "input_pos must be provided when use_kv_cache is True"

if self.params.enable_dynamic_shape:
# when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
input_pos_item = input_pos[-1].item()
torch._check_is_size(input_pos_item)
torch._check(input_pos_item < self.params.max_seq_len)
# pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
# pyre-ignore: Incompatible parameter type [6]
freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
else:
# When not using dynamic shape, use of the .item results in
# symints, due to querying the data from tensor.
# this path avoids that for mps backend, although probably mps backend
# can support dynamic shape?
freqs_cos = self.freqs_cos[input_pos]
freqs_sin = self.freqs_sin[input_pos]

else:
assert input_pos is None, "input_pos is unused when use_kv_cache is False"
freqs_cos = self.freqs_cos[:seqlen]
freqs_sin = self.freqs_sin[:seqlen]
freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)

for layer in self.layers:
h = layer(
Expand Down
28 changes: 16 additions & 12 deletions examples/models/llama/source_transformation/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,27 @@ def materialze_broadcast_of_rope_freq_cis(
module: torch.nn.Module,
):
assert isinstance(module, Transformer)
assert module.freqs_cos.dim() == 2
dim0 = module.freqs_cos.size(0)
dim1 = module.freqs_cos.size(1)
assert module.rope.freqs_cos.dim() == 2
dim0 = module.rope.freqs_cos.size(0)
dim1 = module.rope.freqs_cos.size(1)
module_attention = module.layers[0].attention
assert (
module_attention.n_local_kv_heads == module_attention.n_local_heads
), f"For rope freqs to be materialized for broadcast, q, k, v num heads must match. For q got {module_attention.n_kv_heads} for k got {module_attention.n_local_heads} and v got {module_attention.n_local_kv_heads}"
num_heads = module_attention.n_local_heads
module.freqs_cos = module.freqs_cos.view(dim0, 1, dim1)
module.freqs_cos = module.freqs_cos.expand(dim0, num_heads, dim1).contiguous()
assert module.freqs_sin.dim() == 2
assert dim0 == module.freqs_sin.size(
module.rope.freqs_cos = module.rope.freqs_cos.view(dim0, 1, dim1)
module.rope.freqs_cos = module.rope.freqs_cos.expand(
dim0, num_heads, dim1
).contiguous()
assert module.rope.freqs_sin.dim() == 2
assert dim0 == module.rope.freqs_sin.size(
0
), f"sin and cos freq table sizes must match. Mismatch found at dim 0: {dim0} vs {module.freqs_sin.size(0)}"
assert dim1 == module.freqs_sin.size(
), f"sin and cos freq table sizes must match. Mismatch found at dim 0: {dim0} vs {module.rope.freqs_sin.size(0)}"
assert dim1 == module.rope.freqs_sin.size(
1
), f"sin and cos freq table sizes must match. Mismatch found at dim 1: {dim1} vs {module.freqs_sin.size(1)}"
module.freqs_sin = module.freqs_sin.view(dim0, 1, dim1)
module.freqs_sin = module.freqs_sin.expand(dim0, num_heads, dim1).contiguous()
), f"sin and cos freq table sizes must match. Mismatch found at dim 1: {dim1} vs {module.rope.freqs_sin.size(1)}"
module.rope.freqs_sin = module.rope.freqs_sin.view(dim0, 1, dim1)
module.rope.freqs_sin = module.rope.freqs_sin.expand(
dim0, num_heads, dim1
).contiguous()
return module
Loading