Skip to content

etLLM: add options to apply embedding or output. #8653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions examples/models/llama/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,24 @@ def __init__(self, params: ModelArgs):
self.params = params
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers
self.apply_embedding = params.apply_embedding
self.apply_output = params.apply_output

self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.tok_embeddings = (
nn.Embedding(params.vocab_size, params.dim)
if self.apply_embedding
else None
)
self.rope = Rope(params)
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params, self.rope))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
self.output = (
nn.Linear(params.dim, params.vocab_size, bias=False)
if self.apply_output
else None
)
self.use_kv_cache = params.use_kv_cache
self.generate_full_logits = params.generate_full_logits
self.max_seq_len = params.max_seq_len
Expand All @@ -195,7 +205,7 @@ def forward(
raise ValueError(
"You cannot specify both tokens and h at the same time, and must specify either one"
)
if tokens is not None and h is None:
if self.apply_embedding and tokens is not None and h is None:
h = self.tok_embeddings(tokens)

if attn_options is None:
Expand All @@ -219,7 +229,8 @@ def forward(

h = self.norm(h)

logits = self.output(h)
if self.apply_output:
logits = self.output(h)

if self.output_prune_map is not None:
# expand to original size so that downstream applications can use the logits as-is.
Expand Down
2 changes: 2 additions & 0 deletions examples/models/llama/model_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class ModelArgs:
input_prune_map: Optional[Dict[int, int]] = None
# A dictionary mapping from pruned token-id to original token-id
output_prune_map: Optional[Dict[int, int]] = None
apply_embedding: bool = True # Use embedding inside the transformer
apply_output: bool = True # Use output layer (unembedding) inside the transformer
use_hf_rope: bool = False # Use HuggingFace's RoPE implementation
rope_theta: Optional[float] = (
None # The official name to override self.rope_freq_base.
Expand Down
Loading