|
| 1 | +""" |
| 2 | +Implements the AWQ for llama.cpp use cases. |
| 3 | +Original paper: https://arxiv.org/abs/2306.00978 |
| 4 | +
|
| 5 | +This code is based on versions of the AWQ implementation found in the following repositories: |
| 6 | +* https://github.com/mit-han-lab/llm-awq |
| 7 | +* https://github.com/casper-hansen/AutoAWQ |
| 8 | +""" |
| 9 | + |
| 10 | +import os |
| 11 | +import torch |
| 12 | +import torch.nn as nn |
| 13 | + |
| 14 | +from transformers import AutoModelForCausalLM, AutoConfig |
| 15 | +from transformers.models.bloom.modeling_bloom import BloomGelu |
| 16 | +from transformers.models.llama.modeling_llama import LlamaRMSNorm |
| 17 | +from transformers.activations import GELUActivation |
| 18 | + |
| 19 | + |
| 20 | +class ScaledActivation(nn.Module): |
| 21 | + """ |
| 22 | + ScaledActivation module wraps an existing activation function and applies a |
| 23 | + scale factor to its output. |
| 24 | +
|
| 25 | + Args: |
| 26 | + module (nn.Module): The activation function to be scaled. |
| 27 | + scales (torch.Tensor): A tensor of size (num_features,) containing the initial |
| 28 | + scale factors for each feature. |
| 29 | +
|
| 30 | + Returns: |
| 31 | + torch.Tensor: The scaled output of the activation function. |
| 32 | + """ |
| 33 | + |
| 34 | + def __init__(self, module, scales): |
| 35 | + super().__init__() |
| 36 | + self.act = module |
| 37 | + self.scales = nn.Parameter(scales.data) |
| 38 | + |
| 39 | + def forward(self, x): |
| 40 | + return self.act(x) / self.scales.view(1, 1, -1).to(x.device) |
| 41 | + |
| 42 | + |
| 43 | +def set_op_by_name(layer, name, new_module): |
| 44 | + """ |
| 45 | + Set the new module for given module's name. |
| 46 | +
|
| 47 | + Args: |
| 48 | + layer (nn.Module): The layer in which to replace the submodule. |
| 49 | + name (str): The path to the submodule to be replaced, using dot notation |
| 50 | + to access nested modules. |
| 51 | + new_module (nn.Module): The new module to replace the existing one. |
| 52 | + """ |
| 53 | + levels = name.split(".") |
| 54 | + if len(levels) > 1: |
| 55 | + mod_ = layer |
| 56 | + for l_idx in range(len(levels) - 1): |
| 57 | + if levels[l_idx].isdigit(): |
| 58 | + mod_ = mod_[int(levels[l_idx])] |
| 59 | + else: |
| 60 | + mod_ = getattr(mod_, levels[l_idx]) |
| 61 | + setattr(mod_, levels[-1], new_module) |
| 62 | + else: |
| 63 | + setattr(layer, name, new_module) |
| 64 | + |
| 65 | + |
| 66 | +def get_op_by_name(module, op_name): |
| 67 | + """ |
| 68 | + Retrieves a submodule within a given layer based on its name. |
| 69 | +
|
| 70 | + Args: |
| 71 | + module (nn.Module): The layer containing the submodule to find. |
| 72 | + op_name (str): The name of the submodule. |
| 73 | +
|
| 74 | + Returns: |
| 75 | + nn.Module: The requested submodule found within the given layer. |
| 76 | +
|
| 77 | + Raises: |
| 78 | + ValueError: If the specified submodule cannot be found within the layer. |
| 79 | + """ |
| 80 | + for name, m in module.named_modules(): |
| 81 | + if name == op_name: |
| 82 | + return m |
| 83 | + raise ValueError(f"Cannot find op {op_name} in module {module}") |
| 84 | + |
| 85 | + |
| 86 | +@torch.no_grad() |
| 87 | +def scale_ln_fcs(ln, fcs, scales): |
| 88 | + """ |
| 89 | + Scales the weights of a LayerNorm and a list of fully-connected layers proportionally. |
| 90 | +
|
| 91 | + Args: |
| 92 | + ln (nn.LayerNorm): The LayerNorm module to be scaled. |
| 93 | + fcs (List[nn.Linear]): A list of fully-connected layers to be scaled. |
| 94 | + scales (torch.Tensor): A 1D tensor of size (num_features,). |
| 95 | + """ |
| 96 | + |
| 97 | + if not isinstance(fcs, list): |
| 98 | + fcs = [fcs] |
| 99 | + |
| 100 | + scales = scales.to(ln.weight.device) |
| 101 | + |
| 102 | + ln.weight.div_(scales) |
| 103 | + if hasattr(ln, "bias") and ln.bias is not None: |
| 104 | + ln.bias.div_(scales) |
| 105 | + |
| 106 | + for fc in fcs: |
| 107 | + fc.weight.mul_(scales.view(1, -1)) |
| 108 | + |
| 109 | + for p in ln.parameters(): |
| 110 | + assert torch.isnan(p).sum() == 0 |
| 111 | + for fc in fcs: |
| 112 | + for p in fc.parameters(): |
| 113 | + assert torch.isnan(p).sum() == 0 |
| 114 | + |
| 115 | + |
| 116 | +@torch.no_grad() |
| 117 | +def scale_fc_fc(fc1, fc2, scales): |
| 118 | + """ |
| 119 | + Scales the weights of two fully-connected layers in a specific pattern. |
| 120 | +
|
| 121 | + Args: |
| 122 | + fc1 (nn.Linear): The first fully-connected layer to be scaled. |
| 123 | + fc2 (nn.Linear): The second fully-connected layer to be scaled. |
| 124 | + scales (torch.Tensor): A 1D tensor of size (num_features,). |
| 125 | + """ |
| 126 | + assert isinstance(fc1, nn.Linear) |
| 127 | + assert isinstance(fc2, nn.Linear) |
| 128 | + |
| 129 | + scales = scales.to(fc1.weight.device) |
| 130 | + |
| 131 | + fc1.weight[-scales.size(0):].div_(scales.view(-1, 1)) |
| 132 | + if fc1.bias is not None: |
| 133 | + fc1.bias.div_(scales.view(-1)) |
| 134 | + |
| 135 | + fc2.weight.mul_(scales.view(1, -1)) |
| 136 | + |
| 137 | + for p in fc1.parameters(): |
| 138 | + assert torch.isnan(p).sum() == 0 |
| 139 | + for p in fc2.parameters(): |
| 140 | + assert torch.isnan(p).sum() == 0 |
| 141 | + |
| 142 | + |
| 143 | +@torch.no_grad() |
| 144 | +def scale_gelu_fc(gelu, fc, scales): |
| 145 | + """ |
| 146 | + Scales the weight of a GELU activation and a fully-connected layer proportionally. |
| 147 | +
|
| 148 | + Args: |
| 149 | + gelu (Union[nn.GELU, BloomGelu, GELUActivation]): The GELU activation module to be scaled. |
| 150 | + fc (nn.Linear): The fully-connected layer to be scaled. |
| 151 | + scales (torch.Tensor): A 1D tensor of size (num_features,). |
| 152 | +
|
| 153 | + Raises: |
| 154 | + TypeError: If the `gelu` module is not of type `nn.GELU`, `BloomGelu`, or `GELUActivation`. |
| 155 | + TypeError: If the `fc` module is not of type `nn.Linear`. |
| 156 | + """ |
| 157 | + assert isinstance(gelu, (nn.GELU, BloomGelu, GELUActivation)) |
| 158 | + assert isinstance(fc, nn.Linear) |
| 159 | + |
| 160 | + fc.weight.mul_(scales.view(1, -1).to(fc.weight.device)) |
| 161 | + |
| 162 | + for p in fc.parameters(): |
| 163 | + assert torch.isnan(p).sum() == 0 |
| 164 | + |
| 165 | + |
| 166 | +def apply_scale(module, scales_list, input_feat_dict=None): |
| 167 | + """ |
| 168 | + Applies different scaling strategies to layers based on their type and hierarchy within a given module. |
| 169 | +
|
| 170 | + Args: |
| 171 | + module (nn.Module): The module containing the layers to be scaled. |
| 172 | + scales_list (List[Tuple[str, List[str], torch.Tensor]]): A list of tuples containing: |
| 173 | + * prev_op_name (str): The name of the preceding operation or module, |
| 174 | + relative to which the layers to be scaled are located. |
| 175 | + * layer_names (List[str]): A list of names of the layers to be scaled, relative to the preceding operation. |
| 176 | + * scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature. |
| 177 | + input_feat_dict (Optional[Dict[str, torch.Tensor]]): A dictionary mapping layer names to their corresponding |
| 178 | + input features (optional). |
| 179 | + """ |
| 180 | + for prev_op_name, layer_names, scales in scales_list: |
| 181 | + prev_op = get_op_by_name(module, prev_op_name) |
| 182 | + layers = [get_op_by_name(module, name) for name in layer_names] |
| 183 | + |
| 184 | + prev_op.cuda() |
| 185 | + for layer in layers: |
| 186 | + layer.cuda() |
| 187 | + scales.cuda() |
| 188 | + |
| 189 | + if isinstance(prev_op, nn.Linear): |
| 190 | + assert len(layers) == 1 |
| 191 | + scale_fc_fc(prev_op, layers[0], scales) |
| 192 | + elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)) or "rmsnorm" in str(prev_op.__class__).lower(): |
| 193 | + scale_ln_fcs(prev_op, layers, scales) |
| 194 | + elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation)): |
| 195 | + new_module = ScaledActivation(prev_op, scales) |
| 196 | + set_op_by_name(module, prev_op_name, new_module) |
| 197 | + scale_gelu_fc(prev_op, layers[0], scales) |
| 198 | + else: |
| 199 | + raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!") |
| 200 | + |
| 201 | + # apply the scaling to input feat if given; prepare it for clipping |
| 202 | + if input_feat_dict is not None: |
| 203 | + for layer_name in layer_names: |
| 204 | + inp = input_feat_dict[layer_name] |
| 205 | + inp.div_(scales.view(1, -1).to(inp.device)) |
| 206 | + |
| 207 | + prev_op.cpu() |
| 208 | + for layer in layers: |
| 209 | + layer.cpu() |
| 210 | + scales.cpu() |
| 211 | + |
| 212 | + |
| 213 | +@torch.no_grad() |
| 214 | +def apply_clip(module, clip_list): |
| 215 | + """ |
| 216 | + Applies element-wise clipping to the weight of a specific layer within a given module. |
| 217 | +
|
| 218 | + Args: |
| 219 | + module (nn.Module): The module containing the layer to be clipped. |
| 220 | + clip_list (List[Tuple[str, torch.Tensor]]): A list of tuples containing: |
| 221 | + * name (str): The name of the layer to be clipped, relative to the root of the module. |
| 222 | + * max_val (torch.Tensor): A 1D or 2D tensor defining the upper bound for each element of the layer's weight. |
| 223 | + """ |
| 224 | + for name, max_val in clip_list: |
| 225 | + layer = get_op_by_name(module, name) |
| 226 | + layer.cuda() |
| 227 | + max_val = max_val.to(layer.weight.device) |
| 228 | + org_shape = layer.weight.shape |
| 229 | + layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1) |
| 230 | + layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val) |
| 231 | + layer.weight.data = layer.weight.data.reshape(org_shape) |
| 232 | + layer.cpu() |
| 233 | + |
| 234 | + |
| 235 | +def add_scale_weights(model_path, scale_path, tmp_path): |
| 236 | + """ |
| 237 | + Adds pre-computed Activation Weight Quantization (AWQ) results to a model, |
| 238 | + including scaling factors and clipping bounds. |
| 239 | +
|
| 240 | + Args: |
| 241 | + model_path (str): Path to the pre-trained model to be equipped with AWQ. |
| 242 | + scale_path (str): Path to the AWQ scale factors (.pt file). |
| 243 | + tmp_path (str): Path to the temporary directory where the equipped model will be saved. |
| 244 | + """ |
| 245 | + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) |
| 246 | + model = AutoModelForCausalLM.from_pretrained( |
| 247 | + model_path, config=config, trust_remote_code=True |
| 248 | + ) |
| 249 | + model.eval() |
| 250 | + awq_results = torch.load(str(scale_path), map_location="cpu") |
| 251 | + apply_scale(model, awq_results["scale"]) |
| 252 | + apply_clip(model, awq_results["clip"]) |
| 253 | + model.save_pretrained(str(tmp_path)) |
| 254 | + os.system(f"cp {str(model_path)}/tokenizer* {str(tmp_path)}") |
0 commit comments