
Commit 2afcd96

navsud authored and facebook-github-bot committed
apply output layer pruning (#5426)
Summary:
Pull Request resolved: #5426

Apply output layer pruning when a model trained with a large output vocabulary is used for a classification task that only needs a smaller output vocabulary. The output interface is kept identical to the unpruned model. For example, if the last linear layer has a 2048 x 128k shape and the model is trained to emit only 20 output tokens, we can prune the last layer down to a 2048 x 20 shape. We still expand the 1 x 20 output back to 1 x 128k so that apps consuming the model outputs don't need to change.

Reviewed By: tarun292, iseeyuan

Differential Revision: D62143905

fbshipit-source-id: 95124b37b528a03707ef1192a03aa7e194321b62
1 parent 0a9bbaa commit 2afcd96
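A minimal sketch of the idea in the summary, using small stand-in sizes and a made-up token map (not the actual ExecuTorch code; see the diffs below for the real implementation):

import torch

# Stand-ins for the real sizes (2048 hidden dim, ~128k vocab) so the sketch stays small.
hidden_dim, vocab_size = 64, 1000
token_map = {i: 50 * i + 7 for i in range(20)}  # pruned index -> original token id (made up)

full_output = torch.nn.Linear(hidden_dim, vocab_size, bias=False)
pruned_output = torch.nn.Linear(hidden_dim, len(token_map), bias=False)
with torch.no_grad():
    # Keep only the rows of the output projection for the tokens we care about.
    pruned_output.weight.copy_(full_output.weight[list(token_map.values())])

h = torch.randn(1, hidden_dim)
pruned_logits = pruned_output(h)                       # shape (1, 20)

# Re-expand to the full vocabulary so callers see the same output interface.
expanded = torch.full((1, vocab_size), float("-inf"))
expanded[:, list(token_map.values())] = pruned_logits  # shape (1, 1000)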

5 files changed: 122 additions & 1 deletion


examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ runtime.python_library(
         "export_llama_lib.py",
         "model.py",
         "source_transformation/apply_spin_quant_r1_r2.py",
+        "source_transformation/prune_output.py",
         "source_transformation/quantize.py",
         "source_transformation/rms_norm.py",
         "source_transformation/rope.py",

examples/models/llama2/export_llama_lib.py

Lines changed: 9 additions & 0 deletions
@@ -369,6 +369,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         choices=["cuda", "native"],
         help="Use SpinQuant for better quantization performance. Only support cuda and native.",
     )
+
+    parser.add_argument(
+        "--output_prune_map",
+        default=None,
+        help="path to the output pruning token mapping file (token_map.json)",
+    )
     return parser


@@ -458,6 +464,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         tokenizer_path=args.tokenizer_path,
         verbose=args.verbose,
         max_seq_len=args.max_seq_length,
+        output_prune_map_path=args.output_prune_map,
         metadata_str=args.metadata,
         args=args,
     )
@@ -682,6 +689,7 @@ def _load_llama_model(
     tokenizer_path: Optional[str] = None,
     verbose: bool = False,
     max_seq_len: int = 128,
+    output_prune_map_path: Optional[str] = None,
     metadata_str: Optional[str] = None,
     args,
 ) -> "LLMEdgeManager":
@@ -709,6 +717,7 @@ def _load_llama_model(
         fairseq2=weight_type == WeightType.FAIRSEQ2,
         max_seq_len=max_seq_len,
         enable_dynamic_shape=enable_dynamic_shape,
+        output_prune_map_path=output_prune_map_path,
         args=args,
     )
     state_dict = model.state_dict()
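For reference, the file passed to --output_prune_map is a JSON object mapping each pruned output index to the original token id it stands for (see token_map in prune_output.py below). One hypothetical way to produce such a file, reusing the example token ids from the prune_output.py docstring:

import json

# Hypothetical: the original vocabulary ids the fine-tuned classifier is allowed to emit.
kept_token_ids = [221, 1325, 1542, 1728, 18243]

# Map pruned output index -> original token id. JSON keys are written as strings;
# model.py converts them back to ints after loading.
token_map = {new_id: old_id for new_id, old_id in enumerate(kept_token_ids)}

with open("token_map.json", "w") as f:
    json.dump(token_map, f)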

examples/models/llama2/llama_transformer.py

Lines changed: 27 additions & 1 deletion
@@ -9,7 +9,7 @@

 from dataclasses import dataclass
 from functools import partial
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple

 import torch
 import torch.nn.functional as F
@@ -102,6 +102,8 @@ class ModelArgs:
     # logits for all input tokens.)
     generate_full_logits: bool = False
     enable_dynamic_shape: bool = False  # export model with dynamic shape support
+    # A dictionary mapping from pruned token-id to original token-id
+    output_prune_map: Optional[Dict[int, int]] = None
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
     rope_theta: Optional[float] = (
         None  # The official name to override self.rope_freq_base.
@@ -449,6 +451,7 @@ def __init__(self, params: ModelArgs):
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
+        self.output_prune_map = params.output_prune_map
         if params.use_hf_rope:
             self.precompute_freqs_cis = hf_precompute_freqs_cis
         else:
@@ -525,4 +528,27 @@ def forward(
         h = self.norm(h)

         logits = self.output(h)
+
+        if self.output_prune_map is not None:
+            # expand to original size so that downstream applications can use the logits as-is.
+            if self.generate_full_logits:
+                # (1, seq_len, pruned_size) -> (1, seq_len, original_size)
+                expanded_logits = torch.full(
+                    [logits.shape[0], logits.shape[1], self.vocab_size],
+                    float("-inf"),
+                    device=logits.device,
+                    dtype=logits.dtype,
+                )
+                expanded_logits[:, :, list(self.output_prune_map.values())] = logits
+            else:
+                # (1, pruned_size) -> (1, original_size)
+                expanded_logits = torch.full(
+                    [logits.shape[0], self.vocab_size],
+                    float("-inf"),
+                    device=logits.device,
+                    dtype=logits.dtype,
+                )
+                expanded_logits[:, list(self.output_prune_map.values())] = logits
+            logits = expanded_logits
+
         return logits
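A small standalone sanity check of the expansion added to forward() above, using toy sizes and a made-up output_prune_map; it mirrors the non-generate_full_logits branch, where only the last token's logits of shape (1, pruned_size) are produced:

import torch

vocab_size = 1000                          # stand-in for the full vocabulary size
output_prune_map = {0: 7, 1: 42, 2: 911}   # pruned id -> original token id (made up)

# What the pruned output layer would produce for the last token.
logits = torch.randn(1, len(output_prune_map))

expanded_logits = torch.full(
    [logits.shape[0], vocab_size], float("-inf"), dtype=logits.dtype
)
expanded_logits[:, list(output_prune_map.values())] = logits

assert expanded_logits.shape == (1, vocab_size)
assert expanded_logits[0, 7] == logits[0, 0]   # kept token keeps its score
assert torch.isinf(expanded_logits[0, 0])      # pruned-away tokens stay at -inf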

examples/models/llama2/model.py

Lines changed: 14 additions & 0 deletions
@@ -63,6 +63,7 @@ def __init__(self, **kwargs):
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
+        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)

         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.args = kwargs.get("args", None)
@@ -141,6 +142,12 @@ def __init__(self, **kwargs):
         )
         with open(params_path, "r") as f:
             params = json.loads(f.read())
+        output_prune_map = None
+        if self.output_prune_map_path is not None:
+            with open(self.output_prune_map_path, "r") as f:
+                output_prune_map = json.load(f)
+            # change keys from string to int (json only supports string keys)
+            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
         max_seq_len = self.max_seq_len
         max_batch_size = 1
         model_args: ModelArgs = ModelArgs(
@@ -149,6 +156,7 @@ def __init__(self, **kwargs):
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
             generate_full_logits=self.generate_full_logits,
+            output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
@@ -230,6 +238,12 @@ def __init__(self, **kwargs):
             print(unexpected)
             print("============= /unexpected ================")

+        # prune the output layer if output_prune_map is provided
+        if output_prune_map is not None:
+            from .source_transformation.prune_output import prune_output_vocab
+
+            self.model_ = prune_output_vocab(self.model_, output_prune_map)
+
     def get_eager_model(self):
         if self.dtype:
             # convert to the type of the provided checkpoint
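The string-to-int key conversion above is needed because JSON object keys are always strings; a quick round trip (token ids borrowed from the prune_output.py docstring example) shows the behavior:

import json

token_map = {0: 221, 1: 1325, 2: 1542}
loaded = json.loads(json.dumps(token_map))
print(loaded)  # {'0': 221, '1': 1325, '2': 1542} -- the keys came back as strings

output_prune_map = {int(k): v for (k, v) in loaded.items()}
assert output_prune_map == token_map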
examples/models/llama2/source_transformation/prune_output.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import numpy as np
+
+import torch
+
+
+def prune_output_vocab(
+    model: torch.nn.Module,
+    token_map: Dict[int, int],
+    output_layer_name: str = "output",
+) -> torch.nn.Module:
+    """Prune the model output linear layer while keeping the tokens in the token map.
+
+    Note: Pruning is performed in-place.
+
+    Args:
+        model: The model to prune.
+        token_map: A dictionary mapping from new token ids to the old token ids to preserve.
+            e.g. {0: 221, 1: 1325, 2: 1542, 3: 1728, 4: 18243}
+        output_layer_name: name of the output layer to prune
+
+    Returns:
+        The pruned model.
+    """
+    assert hasattr(
+        model, output_layer_name
+    ), f"Model does not have {output_layer_name} layer"
+    output_layer = getattr(model, output_layer_name)
+    assert isinstance(
+        output_layer, torch.nn.Linear
+    ), "Output layer is not a linear layer"
+    original_shape = output_layer.weight.shape
+    input_features = original_shape[1]
+    num_pruned_tokens = len(token_map)
+    has_bias = output_layer.bias is not None
+    weight_dtype = output_layer.weight.dtype
+    pruned_layer = torch.nn.Linear(input_features, num_pruned_tokens, bias=has_bias)
+    pruned_layer.to(dtype=weight_dtype)
+    pruned_layer_weights = np.zeros(pruned_layer.weight.shape, dtype=np.float32)
+    pruned_layer_bias = None
+    if has_bias:
+        pruned_layer_bias = np.zeros(pruned_layer.bias.shape, dtype=np.float32)
+    for i, token_id in token_map.items():
+        # Copy the weights and biases from the original layer to the pruned layer
+        pruned_wt = output_layer.weight[token_id].detach()
+        if weight_dtype == torch.bfloat16:
+            pruned_wt = pruned_wt.float()
+        pruned_layer_weights[i] = pruned_wt.numpy()
+        if has_bias:
+            pruned_bias = output_layer.bias[token_id].detach()
+            if weight_dtype == torch.bfloat16:
+                pruned_bias = pruned_bias.float()
+            pruned_layer_bias[i] = pruned_bias.numpy()
+    with torch.no_grad():
+        pruned_layer.weight.copy_(
+            torch.tensor(pruned_layer_weights, dtype=weight_dtype)
+        )
+        if has_bias:
+            pruned_layer.bias.copy_(torch.tensor(pruned_layer_bias, dtype=weight_dtype))
+
+    # Replace the original layer with the pruned layer
+    setattr(model, output_layer_name, pruned_layer)
+
+    return model
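A rough usage sketch for prune_output_vocab on a toy module (TinyModel and its sizes are invented for illustration, and the function defined above is assumed to be in scope; the real caller is model.py above):

from typing import Dict

import torch


class TinyModel(torch.nn.Module):
    def __init__(self, hidden: int = 16, vocab: int = 100):
        super().__init__()
        self.output = torch.nn.Linear(hidden, vocab, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.output(h)


token_map: Dict[int, int] = {0: 3, 1: 57, 2: 99}  # pruned id -> original token id (made up)

model = TinyModel()
kept_rows = model.output.weight[[3, 57, 99]].detach().clone()

model = prune_output_vocab(model, token_map)  # replaces model.output in place

assert model.output.weight.shape == (3, 16)            # 100-row layer pruned to 3 rows
assert torch.allclose(model.output.weight, kept_rows)  # weights for kept tokens are preserved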
