
Commit af8e4d7

navsud authored and facebook-github-bot committed
llama export with input vocab pruning (#6421)
Summary: Pull Request resolved: #6421. D62143905 added llama model export with output vocab pruning. Along the same lines, this diff applies the same approach to input vocabulary pruning. The assumption is that the model was trained with the full vocabulary, and the input vocab is pruned after training, at export time. Reviewed By: iseeyuan. Differential Revision: D64723663
Parent: ca47839

File tree: 5 files changed (+75, −2 lines)
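At a high level, the change keeps only selected rows of the model's token-embedding table. A minimal standalone sketch of that idea, using toy tensor names and sizes that are not part of this commit:

import torch

# Toy illustration: keep only the embedding rows for the token ids we preserve.
full_embedding = torch.nn.Embedding(32000, 4096)   # trained with the full vocab
kept_token_ids = [221, 1325, 1542, 1728, 18243]    # original ids to keep

# Row i of the pruned table holds the weights of original token kept_token_ids[i].
pruned_weight = full_embedding.weight[torch.tensor(kept_token_ids)].detach().clone()
pruned_embedding = torch.nn.Embedding.from_pretrained(pruned_weight, freeze=False)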

examples/models/llama/TARGETS

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ runtime.python_library(
         "source_transformation/apply_spin_quant_r1_r2.py",
         "source_transformation/lora.py",
         "source_transformation/pre_quantization.py",
-        "source_transformation/prune_output.py",
+        "source_transformation/prune_vocab.py",
         "source_transformation/quantize.py",
         "source_transformation/quantized_kv_cache.py",
         "source_transformation/rms_norm.py",

examples/models/llama/export_llama_lib.py

Lines changed: 9 additions & 0 deletions
@@ -437,6 +437,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=None,
         help="path to the output pruning token mapping file (token_map.json)",
     )
+
+    parser.add_argument(
+        "--input_prune_map",
+        default=None,
+        help="path to the input pruning token mapping file (token_map.json)",
+    )
     return parser


@@ -525,6 +531,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         tokenizer_path=args.tokenizer_path,
         verbose=args.verbose,
         max_seq_len=args.max_seq_length,
+        input_prune_map_path=args.input_prune_map,
         output_prune_map_path=args.output_prune_map,
         metadata_str=args.metadata,
         dtype_override=dtype_override,
@@ -766,6 +773,7 @@ def _load_llama_model(
     tokenizer_path: Optional[str] = None,
     verbose: bool = False,
     max_seq_len: int = 128,
+    input_prune_map_path: Optional[str] = None,
     output_prune_map_path: Optional[str] = None,
     metadata_str: Optional[str] = None,
     dtype_override: Optional[DType] = None,
@@ -795,6 +803,7 @@ def _load_llama_model(
         fairseq2=weight_type == WeightType.FAIRSEQ2,
         max_seq_len=max_seq_len,
         enable_dynamic_shape=enable_dynamic_shape,
+        input_prune_map_path=input_prune_map_path,
         output_prune_map_path=output_prune_map_path,
         args=args,
     )
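Both --output_prune_map and the new --input_prune_map point at a JSON file that maps new (pruned) token ids to original token ids. A hedged sketch of how such a token_map.json could be produced; which ids to keep is up to the user and is not part of this diff:

import json

# Hypothetical choice of token ids to keep (e.g. those observed in the target domain).
kept_token_ids = sorted([221, 1325, 1542, 1728, 18243])

# Map new (pruned) id -> original id, the format expected by the pruning code.
token_map = {new_id: old_id for new_id, old_id in enumerate(kept_token_ids)}

with open("token_map.json", "w") as f:
    json.dump(token_map, f)

json serializes the integer keys as strings, which is why the loading code in model.py converts them back with int(k).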

examples/models/llama/llama_transformer.py

Lines changed: 3 additions & 0 deletions
@@ -103,6 +103,8 @@ class ModelArgs:
     generate_full_logits: bool = False
     enable_dynamic_shape: bool = False  # export model with dynamic shape support
     # A dictionary mapping from pruned token-id to original token-id
+    input_prune_map: Optional[Dict[int, int]] = None
+    # A dictionary mapping from pruned token-id to original token-id
     output_prune_map: Optional[Dict[int, int]] = None
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
     rope_theta: Optional[float] = (
@@ -461,6 +463,7 @@ def __init__(self, params: ModelArgs):
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
+        self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
         if params.use_hf_rope:
             self.precompute_freqs_cis = hf_precompute_freqs_cis

examples/models/llama/model.py

Lines changed: 15 additions & 1 deletion
@@ -49,6 +49,7 @@ def __init__(self, **kwargs):
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
+        self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.args = kwargs.get("args", None)
@@ -126,13 +127,20 @@ def __init__(self, **kwargs):
                 output_prune_map = json.load(f)
             # Change keys from string to int (json only supports string keys).
             output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
+        input_prune_map = None
+        if self.input_prune_map_path is not None:
+            with open(self.input_prune_map_path, "r") as f:
+                input_prune_map = json.load(f)
+            # Change keys from string to int (json only supports string keys).
+            input_prune_map = {int(k): v for (k, v) in input_prune_map.items()}

         model_args: ModelArgs = ModelArgs(
             max_seq_len=self.max_seq_len,
             max_batch_size=1,
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
             generate_full_logits=self.generate_full_logits,
+            input_prune_map=input_prune_map,
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
@@ -208,10 +216,16 @@ def __init__(self, **kwargs):
             print("============= unexpected keys ================")
             print(unexpected)
             print("============= /unexpected ================")
+
+        # Prune the input layer if input_prune_map is provided
+        if input_prune_map is not None:
+            from .source_transformation.prune_vocab import prune_input_vocab
+
+            self.model_ = prune_input_vocab(self.model_, input_prune_map)

         # Prune the output layer if output_prune_map is provided
         if output_prune_map is not None:
-            from .source_transformation.prune_output import prune_output_vocab
+            from .source_transformation.prune_vocab import prune_output_vocab

             self.model_ = prune_output_vocab(self.model_, output_prune_map)
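A consequence to keep in mind, stated here as an assumption about the runtime rather than something shown in this diff: once the embedding table is pruned, tokenizer output must be remapped from original ids into the pruned id space, i.e. with the inverse of the token map:

# Sketch only; original_token_ids is a hypothetical list of ids from the tokenizer.
# token_map maps pruned-id -> original-id, so invert it for the input side.
inverse_map = {orig_id: pruned_id for pruned_id, orig_id in token_map.items()}
pruned_ids = [inverse_map[t] for t in original_token_ids]  # KeyError if a token was pruned away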

examples/models/llama/source_transformation/prune_output.py renamed to examples/models/llama/source_transformation/prune_vocab.py

Lines changed: 47 additions & 0 deletions
@@ -69,3 +69,50 @@ def prune_output_vocab(
     setattr(model, output_layer_name, pruned_layer)

     return model
+
+def prune_input_vocab(
+    model: torch.nn.Module,
+    token_map: Dict[int, int],
+    input_layer_name: str = "tok_embeddings",
+) -> torch.nn.Module:
+    """Prune the model input embedding layer while keeping the tokens in the token map.
+
+    Note: Pruning is performed in-place.
+
+    Args:
+        model: The model to prune.
+        token_map: A dictionary mapping from new token ids to the old token ids to preserve.
+            e.g. {0: 221, 1: 1325, 2: 1542, 3: 1728, 4: 18243}
+        input_layer_name: name of the input embedding layer to prune.
+
+    Returns:
+        The pruned model.
+    """
+    assert hasattr(
+        model, input_layer_name
+    ), f"Model does not have {input_layer_name} layer"
+    input_layer = getattr(model, input_layer_name)
+    assert isinstance(
+        input_layer, torch.nn.Embedding
+    ), "Input layer is not an Embedding layer"
+    original_shape = input_layer.weight.shape
+    num_pruned_tokens = len(token_map)
+    weight_dtype = input_layer.weight.dtype
+    pruned_layer = torch.nn.Embedding(num_pruned_tokens, original_shape[1])
+    pruned_layer.to(dtype=weight_dtype)
+    pruned_layer_weights = np.zeros(pruned_layer.weight.shape, dtype=np.float32)
+    for i, token_id in token_map.items():
+        # Copy the weights from the original layer to the pruned layer
+        pruned_wt = input_layer.weight[token_id].detach()
+        if weight_dtype == torch.bfloat16:
+            pruned_wt = pruned_wt.float()
+        pruned_layer_weights[i] = pruned_wt.numpy()
+    with torch.no_grad():
+        pruned_layer.weight.copy_(
+            torch.tensor(pruned_layer_weights, dtype=weight_dtype)
+        )
+
+    # Replace the original layer with the pruned layer
+    setattr(model, input_layer_name, pruned_layer)
+
+    return model
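A small usage sketch of prune_input_vocab on a toy module; TinyModel, the import path, and the numbers below are made up for illustration, only prune_input_vocab itself comes from this commit:

import torch
from examples.models.llama.source_transformation.prune_vocab import prune_input_vocab  # assumed import path

class TinyModel(torch.nn.Module):
    def __init__(self, vocab_size: int = 32000, dim: int = 64):
        super().__init__()
        # Same attribute name as the default input_layer_name above.
        self.tok_embeddings = torch.nn.Embedding(vocab_size, dim)

model = TinyModel()
token_map = {0: 221, 1: 1325, 2: 1542}  # pruned-id -> original-id
model = prune_input_vocab(model, token_map)
assert model.tok_embeddings.weight.shape == (3, 64)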
