Commit 89ba47a

llama export with input vocab pruning

Differential Revision: D64723663
Pull Request resolved: #6421

1 parent 0aa802d · commit 89ba47a

5 files changed: +76, -2 lines

examples/models/llama/TARGETS

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ runtime.python_library(
         "source_transformation/apply_spin_quant_r1_r2.py",
         "source_transformation/lora.py",
         "source_transformation/pre_quantization.py",
-        "source_transformation/prune_output.py",
+        "source_transformation/prune_vocab.py",
         "source_transformation/quantize.py",
         "source_transformation/quantized_kv_cache.py",
         "source_transformation/rms_norm.py",

examples/models/llama/export_llama_lib.py

Lines changed: 9 additions & 0 deletions

@@ -437,6 +437,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=None,
         help="path to the output pruning token mapping file (token_map.json)",
     )
+
+    parser.add_argument(
+        "--input_prune_map",
+        default=None,
+        help="path to the input pruning token mapping file (token_map.json)",
+    )
     return parser

@@ -525,6 +531,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             tokenizer_path=args.tokenizer_path,
             verbose=args.verbose,
             max_seq_len=args.max_seq_length,
+            input_prune_map_path=args.input_prune_map,
             output_prune_map_path=args.output_prune_map,
             metadata_str=args.metadata,
             dtype_override=dtype_override,

@@ -766,6 +773,7 @@ def _load_llama_model(
     tokenizer_path: Optional[str] = None,
     verbose: bool = False,
     max_seq_len: int = 128,
+    input_prune_map_path: Optional[str] = None,
     output_prune_map_path: Optional[str] = None,
     metadata_str: Optional[str] = None,
     dtype_override: Optional[DType] = None,

@@ -795,6 +803,7 @@ def _load_llama_model(
             fairseq2=weight_type == WeightType.FAIRSEQ2,
             max_seq_len=max_seq_len,
             enable_dynamic_shape=enable_dynamic_shape,
+            input_prune_map_path=input_prune_map_path,
             output_prune_map_path=output_prune_map_path,
             args=args,
         )
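For context (an editorial note, not part of the diff): the new --input_prune_map flag consumes the same token_map.json shape already used by --output_prune_map, a JSON object whose keys are the new (pruned) token ids as strings and whose values are the original token ids to keep, e.g. {"0": 221, "1": 1325, ...}. A minimal sketch for producing such a file, using a made-up list of kept token ids:

```python
import json

# Hypothetical list of original-vocabulary token ids to keep after pruning.
kept_token_ids = [221, 1325, 1542, 1728, 18243]

# New (pruned) id -> original id. Keys are strings because JSON only supports
# string keys; model.py converts them back to ints when the file is loaded.
token_map = {str(new_id): old_id for new_id, old_id in enumerate(sorted(kept_token_ids))}

with open("token_map.json", "w") as f:
    json.dump(token_map, f, indent=2)
```

The resulting file is what gets passed at export time via --input_prune_map token_map.json (and, for output pruning, --output_prune_map).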

examples/models/llama/llama_transformer.py

Lines changed: 3 additions & 0 deletions

@@ -103,6 +103,8 @@ class ModelArgs:
     generate_full_logits: bool = False
     enable_dynamic_shape: bool = False  # export model with dynamic shape support
     # A dictionary mapping from pruned token-id to original token-id
+    input_prune_map: Optional[Dict[int, int]] = None
+    # A dictionary mapping from pruned token-id to original token-id
     output_prune_map: Optional[Dict[int, int]] = None
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
     rope_theta: Optional[float] = (

@@ -461,6 +463,7 @@ def __init__(self, params: ModelArgs):
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
+        self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
         if params.use_hf_rope:
             self.precompute_freqs_cis = hf_precompute_freqs_cis

examples/models/llama/model.py

Lines changed: 15 additions & 1 deletion

@@ -49,6 +49,7 @@ def __init__(self, **kwargs):
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
+        self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.args = kwargs.get("args", None)

@@ -126,13 +127,20 @@ def __init__(self, **kwargs):
                 output_prune_map = json.load(f)
             # Change keys from string to int (json only supports string keys).
             output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
+        input_prune_map = None
+        if self.input_prune_map_path is not None:
+            with open(self.input_prune_map_path, "r") as f:
+                input_prune_map = json.load(f)
+            # Change keys from string to int (json only supports string keys).
+            input_prune_map = {int(k): v for (k, v) in input_prune_map.items()}

         model_args: ModelArgs = ModelArgs(
             max_seq_len=self.max_seq_len,
             max_batch_size=1,
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
             generate_full_logits=self.generate_full_logits,
+            input_prune_map=input_prune_map,
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,

@@ -209,9 +217,15 @@ def __init__(self, **kwargs):
             print(unexpected)
             print("============= /unexpected ================")

+        # Prune the input layer if input_prune_map is provided
+        if input_prune_map is not None:
+            from .source_transformation.prune_vocab import prune_input_vocab
+
+            self.model_ = prune_input_vocab(self.model_, input_prune_map)
+
         # Prune the output layer if output_prune_map is provided
         if output_prune_map is not None:
-            from .source_transformation.prune_output import prune_output_vocab
+            from .source_transformation.prune_vocab import prune_output_vocab

             self.model_ = prune_output_vocab(self.model_, output_prune_map)
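A note on the runtime side (an assumption drawn from this diff, not code in the commit): once tok_embeddings is pruned, its rows are indexed by the new ids 0..N-1, so token ids produced by the original tokenizer have to be remapped through the inverse of the input prune map before they are fed to the exported model. A minimal sketch:

```python
import torch

# New (pruned) id -> original id, as loaded from token_map.json in model.py.
input_prune_map = {0: 221, 1: 1325, 2: 1542, 3: 1728, 4: 18243}

# Invert the map to translate original tokenizer ids into pruned-embedding indices.
original_to_pruned = {old_id: new_id for new_id, old_id in input_prune_map.items()}

raw_tokens = [221, 1542, 18243]  # ids emitted by the original tokenizer
tokens = torch.tensor([[original_to_pruned[t] for t in raw_tokens]], dtype=torch.long)
# tokens == tensor([[0, 2, 4]]): valid indices into the pruned embedding table.
```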

examples/models/llama/source_transformation/prune_output.py renamed to examples/models/llama/source_transformation/prune_vocab.py

Lines changed: 48 additions & 0 deletions

@@ -69,3 +69,51 @@ def prune_output_vocab(
     setattr(model, output_layer_name, pruned_layer)

     return model
+
+
+def prune_input_vocab(
+    model: torch.nn.Module,
+    token_map: Dict[int, int],
+    imput_layer_name: str = "tok_embeddings",
+) -> torch.nn.Module:
+    """Prune the model input embedding layer while keeping the tokens in the token map.
+
+    Note: Pruning is performed in-place.
+
+    Args:
+        model: The model to prune.
+        token_map: A dictionary mapping from new token ids to the old token ids to preserve.
+            e.g. {0: 221, 1: 1325, 2: 1542, 3: 1728, 4: 18243}
+        imput_layer_name: name of the input embedding layer to prune
+
+    Returns:
+        The pruned model.
+    """
+    assert hasattr(
+        model, imput_layer_name
+    ), f"Model does not have {imput_layer_name} layer"
+    input_layer = getattr(model, imput_layer_name)
+    assert isinstance(
+        input_layer, torch.nn.Embedding
+    ), "Input layer is not an Embedding layer"
+    original_shape = input_layer.weight.shape
+    num_pruned_tokens = len(token_map)
+    weight_dtype = input_layer.weight.dtype
+    pruned_layer = torch.nn.Embedding(num_pruned_tokens, original_shape[1])
+    pruned_layer.to(dtype=weight_dtype)
+    pruned_layer_weights = np.zeros(pruned_layer.weight.shape, dtype=np.float32)
+    for i, token_id in token_map.items():
+        # Copy the weights from the original layer to the pruned layer
+        pruned_wt = input_layer.weight[token_id].detach()
+        if weight_dtype == torch.bfloat16:
+            pruned_wt = pruned_wt.float()
+        pruned_layer_weights[i] = pruned_wt.numpy()
+    with torch.no_grad():
+        pruned_layer.weight.copy_(
+            torch.tensor(pruned_layer_weights, dtype=weight_dtype)
+        )
+
+    # Replace the original layer with the pruned layer
+    setattr(model, imput_layer_name, pruned_layer)
+
+    return model
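A standalone usage sketch for the new prune_input_vocab (not from the commit; the toy model is made up and the import path is an assumption, written relative to the repo root): it keeps only the embedding rows named in token_map and re-indexes them from zero.

```python
import torch

# Assumed import path; adjust to how the package is laid out in your checkout.
from examples.models.llama.source_transformation.prune_vocab import prune_input_vocab


class ToyModel(torch.nn.Module):
    """Minimal stand-in exposing the tok_embeddings attribute the function expects."""

    def __init__(self, vocab_size: int = 32, dim: int = 8):
        super().__init__()
        self.tok_embeddings = torch.nn.Embedding(vocab_size, dim)


model = ToyModel()
original_weight = model.tok_embeddings.weight.detach().clone()

# Keep original tokens 2, 5, and 7; they become pruned ids 0, 1, and 2.
token_map = {0: 2, 1: 5, 2: 7}
model = prune_input_vocab(model, token_map)

assert model.tok_embeddings.num_embeddings == 3
assert torch.allclose(model.tok_embeddings.weight[1], original_weight[5])
```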
