llama export with input vocab pruning #6421

Merged 1 commit on Oct 22, 2024
2 changes: 1 addition & 1 deletion examples/models/llama/TARGETS
@@ -84,7 +84,7 @@ runtime.python_library(
"source_transformation/apply_spin_quant_r1_r2.py",
"source_transformation/lora.py",
"source_transformation/pre_quantization.py",
"source_transformation/prune_output.py",
"source_transformation/prune_vocab.py",
"source_transformation/quantize.py",
"source_transformation/quantized_kv_cache.py",
"source_transformation/rms_norm.py",
9 changes: 9 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -437,6 +437,12 @@ def build_args_parser() -> argparse.ArgumentParser:
default=None,
help="path to the output pruning token mapping file (token_map.json)",
)

parser.add_argument(
"--input_prune_map",
default=None,
help="path to the input pruning token mapping file (token_map.json)",
)
return parser


@@ -525,6 +531,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
tokenizer_path=args.tokenizer_path,
verbose=args.verbose,
max_seq_len=args.max_seq_length,
input_prune_map_path=args.input_prune_map,
output_prune_map_path=args.output_prune_map,
metadata_str=args.metadata,
dtype_override=dtype_override,
@@ -766,6 +773,7 @@ def _load_llama_model(
tokenizer_path: Optional[str] = None,
verbose: bool = False,
max_seq_len: int = 128,
input_prune_map_path: Optional[str] = None,
output_prune_map_path: Optional[str] = None,
metadata_str: Optional[str] = None,
dtype_override: Optional[DType] = None,
@@ -795,6 +803,7 @@ def _load_llama_model(
fairseq2=weight_type == WeightType.FAIRSEQ2,
max_seq_len=max_seq_len,
enable_dynamic_shape=enable_dynamic_shape,
input_prune_map_path=input_prune_map_path,
output_prune_map_path=output_prune_map_path,
args=args,
)
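The diff does not show what the token_map.json passed to the new --input_prune_map flag looks like, but from the help text above and the JSON loading in model.py below, it is a JSON object whose keys are the pruned (new) token ids and whose values are the original token ids to keep. A minimal sketch of producing such a file (the kept token ids are made up for illustration):

import json

# Original-vocabulary token ids to keep; everything else gets pruned.
kept_token_ids = [221, 1325, 1542, 1728, 18243]

# Map pruned (new) id -> original id, e.g. {0: 221, 1: 1325, ...}.
token_map = {new_id: old_id for new_id, old_id in enumerate(kept_token_ids)}

# JSON stores the integer keys as strings; model.py converts them back to int on load.
with open("token_map.json", "w") as f:
    json.dump(token_map, f)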
3 changes: 3 additions & 0 deletions examples/models/llama/llama_transformer.py
@@ -103,6 +103,8 @@ class ModelArgs:
generate_full_logits: bool = False
enable_dynamic_shape: bool = False # export model with dynamic shape support
# A dictionary mapping from pruned token-id to original token-id
input_prune_map: Optional[Dict[int, int]] = None
# A dictionary mapping from pruned token-id to original token-id
output_prune_map: Optional[Dict[int, int]] = None
use_hf_rope: bool = False # Use HuggingFace's RoPE implementation
rope_theta: Optional[float] = (
@@ -461,6 +463,7 @@ def __init__(self, params: ModelArgs):
self.use_kv_cache = params.use_kv_cache
self.generate_full_logits = params.generate_full_logits
self.max_seq_len = params.max_seq_len
self.input_prune_map = params.input_prune_map
self.output_prune_map = params.output_prune_map
if params.use_hf_rope:
self.precompute_freqs_cis = hf_precompute_freqs_cis
16 changes: 15 additions & 1 deletion examples/models/llama/model.py
@@ -49,6 +49,7 @@ def __init__(self, **kwargs):
self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
self.generate_full_logits = kwargs.get("generate_full_logits", False)
self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
self.max_seq_len = kwargs.get("max_seq_len", 128)
self.args = kwargs.get("args", None)
@@ -126,13 +127,20 @@ def __init__(self, **kwargs):
output_prune_map = json.load(f)
# Change keys from string to int (json only supports string keys).
output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
input_prune_map = None
if self.input_prune_map_path is not None:
with open(self.input_prune_map_path, "r") as f:
input_prune_map = json.load(f)
# Change keys from string to int (json only supports string keys).
input_prune_map = {int(k): v for (k, v) in input_prune_map.items()}

model_args: ModelArgs = ModelArgs(
max_seq_len=self.max_seq_len,
max_batch_size=1,
use_kv_cache=self.use_kv_cache,
use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
generate_full_logits=self.generate_full_logits,
input_prune_map=input_prune_map,
output_prune_map=output_prune_map,
enable_dynamic_shape=self.enable_dynamic_shape,
**params,
@@ -209,9 +217,15 @@ def __init__(self, **kwargs):
print(unexpected)
print("============= /unexpected ================")

# Prune the input layer if input_prune_map is provided
if input_prune_map is not None:
from .source_transformation.prune_vocab import prune_input_vocab

self.model_ = prune_input_vocab(self.model_, input_prune_map)

# Prune the output layer if output_prune_map is provided
if output_prune_map is not None:
-from .source_transformation.prune_output import prune_output_vocab
+from .source_transformation.prune_vocab import prune_output_vocab

self.model_ = prune_output_vocab(self.model_, output_prune_map)

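One practical note on the input-pruning path above: once prune_input_vocab has run, tok_embeddings only has len(input_prune_map) rows, so token ids coming out of the tokenizer have to be remapped to the pruned ids before they reach the model. A hedged sketch of that remapping (remap_input_ids is a hypothetical helper, not part of this PR):

from typing import Dict, List

def remap_input_ids(
    token_ids: List[int], input_prune_map: Dict[int, int]
) -> List[int]:
    # input_prune_map goes from pruned (new) id to original id, so invert it.
    old_to_new = {old_id: new_id for new_id, old_id in input_prune_map.items()}
    # Tokens outside the map no longer have an embedding row; fail loudly here.
    return [old_to_new[token_id] for token_id in token_ids]

# With the sample map from the prune_input_vocab docstring below,
# {0: 221, 1: 1325, 2: 1542, 3: 1728, 4: 18243}:
# remap_input_ids([1325, 221], ...) -> [1, 0]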
examples/models/llama/source_transformation/prune_vocab.py
@@ -69,3 +69,51 @@ def prune_output_vocab(
    setattr(model, output_layer_name, pruned_layer)

    return model


def prune_input_vocab(
    model: torch.nn.Module,
    token_map: Dict[int, int],
    input_layer_name: str = "tok_embeddings",
) -> torch.nn.Module:
    """Prune the model input embedding layer while keeping the tokens in the token map.

    Note: Pruning is performed in-place.

    Args:
        model: The model to prune.
        token_map: A dictionary mapping from new token ids to the old token ids to preserve.
            e.g. {0: 221, 1: 1325, 2: 1542, 3: 1728, 4: 18243}
        input_layer_name: name of the input embedding layer to prune.

    Returns:
        The pruned model.
    """
    assert hasattr(
        model, input_layer_name
    ), f"Model does not have {input_layer_name} layer"
    input_layer = getattr(model, input_layer_name)
    assert isinstance(
        input_layer, torch.nn.Embedding
    ), "Input layer is not an Embedding layer"
    original_shape = input_layer.weight.shape
    num_pruned_tokens = len(token_map)
    weight_dtype = input_layer.weight.dtype
    pruned_layer = torch.nn.Embedding(num_pruned_tokens, original_shape[1])
    pruned_layer.to(dtype=weight_dtype)
    pruned_layer_weights = np.zeros(pruned_layer.weight.shape, dtype=np.float32)
    for i, token_id in token_map.items():
        # Copy the weights from the original layer to the pruned layer.
        pruned_wt = input_layer.weight[token_id].detach()
        if weight_dtype == torch.bfloat16:
            pruned_wt = pruned_wt.float()
        pruned_layer_weights[i] = pruned_wt.numpy()
    with torch.no_grad():
        pruned_layer.weight.copy_(
            torch.tensor(pruned_layer_weights, dtype=weight_dtype)
        )

    # Replace the original layer with the pruned layer.
    setattr(model, input_layer_name, pruned_layer)

    return model
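Finally, a small standalone usage sketch of the new function on a toy module (not the Llama transformer), just to show the shape change and row copying. The import path follows the TARGETS entry above and assumes the repo root is on the Python path:

import torch

from examples.models.llama.source_transformation.prune_vocab import prune_input_vocab

class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Toy vocabulary of 32000 tokens, embedding dim 8.
        self.tok_embeddings = torch.nn.Embedding(32000, 8)

model = TinyModel()
token_map = {0: 221, 1: 1325, 2: 1542, 3: 1728, 4: 18243}
model = prune_input_vocab(model, token_map)

print(model.tok_embeddings.weight.shape)  # torch.Size([5, 8])
# Row i of the pruned embedding now holds the original weights of token token_map[i].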