
Commit df4181c

navsud authored and facebook-github-bot committed
llama export with input vocab pruning (#6421)
Summary:
Pull Request resolved: #6421

D62143905 added llama model export with output vocab pruning. Along the same lines, this diff applies input vocabulary pruning. The assumption is that the model was trained with the full vocabulary and the input vocabulary is pruned out after training, at export time.

Reviewed By: iseeyuan

Differential Revision: D64723663
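As background on the token_map.json format mentioned above: judging from the ModelArgs comment in the diff, the map goes from pruned token-id to original token-id, with string keys because JSON requires them. A minimal sketch of producing such a file from a list of kept token-ids; the token-ids and the file name are hypothetical, not taken from this commit:

import json

# Hypothetical original token-ids to keep in the pruned input vocabulary.
kept_token_ids = [128000, 271, 374, 1734]

# Pruned token-id (new, contiguous) -> original token-id.
token_map = {new_id: orig_id for new_id, orig_id in enumerate(kept_token_ids)}

with open("token_map.json", "w") as f:
    json.dump(token_map, f)  # int keys are serialized as strings; the loader converts them back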
1 parent 3ea8538 commit df4181c

File tree

5 files changed: +63 lines, -2 lines

examples/models/llama/TARGETS

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ runtime.python_library(
         "source_transformation/apply_spin_quant_r1_r2.py",
         "source_transformation/lora.py",
         "source_transformation/pre_quantization.py",
-        "source_transformation/prune_output.py",
+        "source_transformation/prune_vocab.py",
         "source_transformation/quantize.py",
         "source_transformation/quantized_kv_cache.py",
         "source_transformation/rms_norm.py",

examples/models/llama/export_llama_lib.py

Lines changed: 9 additions & 0 deletions
@@ -436,6 +436,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=None,
         help="path to the output pruning token mapping file (token_map.json)",
     )
+
+    parser.add_argument(
+        "--input_prune_map",
+        default=None,
+        help="path to the input pruning token mapping file (token_map.json)",
+    )
     return parser
 
 
@@ -524,6 +530,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         tokenizer_path=args.tokenizer_path,
         verbose=args.verbose,
         max_seq_len=args.max_seq_length,
+        input_prune_map_path=args.input_prune_map,
         output_prune_map_path=args.output_prune_map,
         metadata_str=args.metadata,
         dtype_override=dtype_override,
@@ -765,6 +772,7 @@ def _load_llama_model(
     tokenizer_path: Optional[str] = None,
     verbose: bool = False,
     max_seq_len: int = 128,
+    input_prune_map_path: Optional[str] = None,
     output_prune_map_path: Optional[str] = None,
     metadata_str: Optional[str] = None,
     dtype_override: Optional[DType] = None,
@@ -794,6 +802,7 @@ def _load_llama_model(
         fairseq2=weight_type == WeightType.FAIRSEQ2,
         max_seq_len=max_seq_len,
         enable_dynamic_shape=enable_dynamic_shape,
+        input_prune_map_path=input_prune_map_path,
         output_prune_map_path=output_prune_map_path,
         args=args,
     )
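With the new flag in place, the path to the input pruning map is picked up at export time and threaded through to model loading. A self-contained sketch of just the argument introduced in this diff (the real parser carries many more flags; the file path below is hypothetical):

import argparse

# Reproduces only the argument added by this commit, in isolation.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--input_prune_map",
    default=None,
    help="path to the input pruning token mapping file (token_map.json)",
)

args = parser.parse_args(["--input_prune_map", "token_map.json"])
assert args.input_prune_map == "token_map.json"  # later forwarded as input_prune_map_path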

examples/models/llama/llama_transformer.py

Lines changed: 3 additions & 0 deletions
@@ -103,6 +103,8 @@ class ModelArgs:
     generate_full_logits: bool = False
     enable_dynamic_shape: bool = False  # export model with dynamic shape support
     # A dictionary mapping from pruned token-id to original token-id
+    input_prune_map: Optional[Dict[int, int]] = None
+    # A dictionary mapping from pruned token-id to original token-id
     output_prune_map: Optional[Dict[int, int]] = None
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
     rope_theta: Optional[float] = (
@@ -461,6 +463,7 @@ def __init__(self, params: ModelArgs):
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
+        self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
         if params.use_hf_rope:
             self.precompute_freqs_cis = hf_precompute_freqs_cis

examples/models/llama/model.py

Lines changed: 15 additions & 1 deletion
@@ -49,6 +49,7 @@ def __init__(self, **kwargs):
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
+        self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.args = kwargs.get("args", None)
@@ -126,13 +127,20 @@ def __init__(self, **kwargs):
                 output_prune_map = json.load(f)
             # Change keys from string to int (json only supports string keys).
             output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
+        input_prune_map = None
+        if self.input_prune_map_path is not None:
+            with open(self.input_prune_map_path, "r") as f:
+                input_prune_map = json.load(f)
+            # Change keys from string to int (json only supports string keys).
+            input_prune_map = {int(k): v for (k, v) in input_prune_map.items()}
 
         model_args: ModelArgs = ModelArgs(
             max_seq_len=self.max_seq_len,
             max_batch_size=1,
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
             generate_full_logits=self.generate_full_logits,
+            input_prune_map=input_prune_map,
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
@@ -208,10 +216,16 @@ def __init__(self, **kwargs):
             print("============= unexpected keys ================")
             print(unexpected)
             print("============= /unexpected ================")
+
+        # Prune the input layer if input_prune_map is provided
+        if input_prune_map is not None:
+            from .source_transformation.prune_vocab import prune_input_vocab
+
+            self.model_ = prune_input_vocab(self.model_, input_prune_map)
 
         # Prune the output layer if output_prune_map is provided
         if output_prune_map is not None:
-            from .source_transformation.prune_output import prune_output_vocab
+            from .source_transformation.prune_vocab import prune_output_vocab
 
             self.model_ = prune_output_vocab(self.model_, output_prune_map)
 
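Pulled out of the diff as a standalone sketch, this is the loading step the model wrapper now performs for the input map; the file name and its contents are hypothetical here:

import json

input_prune_map_path = "token_map.json"  # in practice supplied via --input_prune_map

input_prune_map = None
if input_prune_map_path is not None:
    with open(input_prune_map_path, "r") as f:
        input_prune_map = json.load(f)  # e.g. {"0": 128000, "1": 271}
    # Change keys from string to int (json only supports string keys).
    input_prune_map = {int(k): v for (k, v) in input_prune_map.items()}  # {0: 128000, 1: 271}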

examples/models/llama/source_transformation/prune_output.py renamed to examples/models/llama/source_transformation/prune_vocab.py

Lines changed: 35 additions & 0 deletions
@@ -69,3 +69,38 @@ def prune_output_vocab(
     setattr(model, output_layer_name, pruned_layer)
 
     return model
+
+
+def prune_input_vocab(
+    model: torch.nn.Module,
+    token_map: Dict[int, int],
+    imput_layer_name: str = "tok_embeddings",
+) -> torch.nn.Module:
+    assert hasattr(
+        model, imput_layer_name
+    ), f"Model does not have {imput_layer_name} layer"
+    input_layer = getattr(model, imput_layer_name)
+    assert isinstance(
+        input_layer, torch.nn.Embedding
+    ), "Input layer is not an Embedding layer"
+    original_shape = input_layer.weight.shape
+    num_pruned_tokens = len(token_map)
+    weight_dtype = input_layer.weight.dtype
+    pruned_layer = torch.nn.Embedding(num_pruned_tokens, original_shape[1])
+    pruned_layer.to(dtype=weight_dtype)
+    pruned_layer_weights = np.zeros(pruned_layer.weight.shape, dtype=np.float32)
+    for i, token_id in token_map.items():
+        # Copy the weights from the original layer to the pruned layer
+        pruned_wt = input_layer.weight[token_id].detach()
+        if weight_dtype == torch.bfloat16:
+            pruned_wt = pruned_wt.float()
+        pruned_layer_weights[i] = pruned_wt.numpy()
+    with torch.no_grad():
+        pruned_layer.weight.copy_(
+            torch.tensor(pruned_layer_weights, dtype=weight_dtype)
+        )
+
+    # Replace the original layer with the pruned layer
+    setattr(model, imput_layer_name, pruned_layer)
+
+    return model
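To see the transformation in isolation, here is a hedged usage sketch: the toy module, dimensions, and token-ids are hypothetical, and the import path is assumed from the file location; only the function itself and the default tok_embeddings attribute name come from this diff:

import torch

from examples.models.llama.source_transformation.prune_vocab import prune_input_vocab

class ToyModel(torch.nn.Module):
    def __init__(self, vocab_size: int = 32000, dim: int = 16):
        super().__init__()
        # Matches the function's default imput_layer_name ("tok_embeddings").
        self.tok_embeddings = torch.nn.Embedding(vocab_size, dim)

model = ToyModel()
token_map = {0: 128, 1: 7, 2: 31999}  # pruned token-id -> original token-id (hypothetical)

model = prune_input_vocab(model, token_map)
print(model.tok_embeddings.weight.shape)  # torch.Size([3, 16]); only the kept rows remain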

0 commit comments
