|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import typing |
| 8 | + |
| 9 | +import torch |
| 10 | + |
| 11 | + |
| 12 | +def rotate_embeddings(model, R1: torch.Tensor) -> None: |
| 13 | + # Rotate the embeddings. |
| 14 | + for W in [model.tok_embeddings]: |
| 15 | + dtype = W.weight.data.dtype |
| 16 | + W_ = W.weight.data.to(device="cpu", dtype=torch.float32) |
| 17 | + W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype) |
| 18 | + |
| 19 | + |
| 20 | +def rotate_attention_inputs(layer, R1) -> None: |
| 21 | + # Rotate the WQ, WK and WV matrices of the self-attention layer. |
| 22 | + for W in [layer.attention.wq, layer.attention.wk, layer.attention.wv]: |
| 23 | + dtype = W.weight.dtype |
| 24 | + W_ = W.weight.to(device="cpu", dtype=torch.float32) |
| 25 | + W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype) |
| 26 | + |
| 27 | + |
| 28 | +def rotate_attention_output(layer, R1) -> None: |
| 29 | + # Rotate output matrix of the self-attention layer. |
| 30 | + W = layer.attention.wo |
| 31 | + dtype = W.weight.data.dtype |
| 32 | + W_ = W.weight.data.to(device="cpu", dtype=torch.float32) |
| 33 | + W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype) |
| 34 | + if W.bias is not None: |
| 35 | + b = W.bias.data.to(device="cpu", dtype=torch.float32) |
| 36 | + W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype) |
| 37 | + |
| 38 | + |
| 39 | +def rotate_mlp_input(layer, R1): |
| 40 | + # Rotate the MLP input weights. |
| 41 | + mlp_inputs = [layer.feed_forward.w3, layer.feed_forward.w1] |
| 42 | + for W in mlp_inputs: |
| 43 | + dtype = W.weight.dtype |
| 44 | + W_ = W.weight.data.to(device="cpu", dtype=torch.float32) |
| 45 | + W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype) |
| 46 | + |
| 47 | + |
| 48 | +def rotate_mlp_output(layer, R1): |
| 49 | + # Rotate the MLP output weights and bias. |
| 50 | + W = layer.feed_forward.w2 |
| 51 | + dtype = W.weight.data.dtype |
| 52 | + W_ = W.weight.data.to(device="cpu", dtype=torch.float32) |
| 53 | + W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype) |
| 54 | + |
| 55 | + if W.bias is not None: |
| 56 | + b = W.bias.data.to(device="cpu", dtype=torch.float32) |
| 57 | + W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype) |
| 58 | + |
| 59 | + |
| 60 | +def rotate_head(model, R1: torch.Tensor) -> None: |
| 61 | + # Rotate the head. |
| 62 | + W = model.output |
| 63 | + dtype = W.weight.data.dtype |
| 64 | + W_ = W.weight.data.to(device="cpu", dtype=torch.float32) |
| 65 | + W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype) |
| 66 | + |
| 67 | + |
| 68 | +def rotate_ov_proj(layer, head_dim, R2=None): |
| 69 | + W = layer.attention.wv |
| 70 | + dtype = W.weight.data.dtype |
| 71 | + W_ = W.weight.data.to(device="cpu", dtype=torch.float32).t() |
| 72 | + transposed_shape = W_.shape |
| 73 | + temp = W_.reshape(-1, transposed_shape[-1] // head_dim, head_dim) |
| 74 | + temp = temp.to(torch.float32) @ R2 |
| 75 | + W_ = temp.reshape(transposed_shape).t() |
| 76 | + W.weight.data = W_.to(device="cpu", dtype=dtype) |
| 77 | + |
| 78 | + W = layer.attention.wo |
| 79 | + dtype = W.weight.data.dtype |
| 80 | + W_ = W.weight.data.to(device="cpu", dtype=torch.float32) |
| 81 | + init_shape = W_.shape |
| 82 | + temp = W_.reshape(-1, init_shape[-1] // head_dim, head_dim) |
| 83 | + temp = temp.to(torch.float32) @ R2 |
| 84 | + W_ = temp.reshape(init_shape) |
| 85 | + W.weight.data = W_.to(device="cpu", dtype=dtype) |
| 86 | + |
| 87 | + |
| 88 | +def cleanup_memory() -> None: |
| 89 | + """Run GC and clear GPU memory.""" |
| 90 | + import gc |
| 91 | + |
| 92 | + # gc.collect and empty cache are necessary to clean up GPU memory if the model was distributed |
| 93 | + gc.collect() |
| 94 | + |
| 95 | + |
| 96 | +def get_model_with_r1_r2(optimized_rotation_path: str): |
| 97 | + return lambda model: apply_spin_quant_r1_r2(model, optimized_rotation_path) |
| 98 | + |
| 99 | + |
| 100 | +def apply_spin_quant_r1_r2(model: torch.nn.Module, optimized_rotation_path: str): |
| 101 | + optimized_rotation = torch.load(optimized_rotation_path, weights_only=True) |
| 102 | + R1 = optimized_rotation["R1"].to(torch.float32) |
| 103 | + config = model.params |
| 104 | + num_heads = config.n_heads |
| 105 | + head_dim = config.dim // num_heads |
| 106 | + |
| 107 | + rotate_embeddings(model, R1) |
| 108 | + rotate_head(model, R1) |
| 109 | + cleanup_memory() |
| 110 | + |
| 111 | + for idx, layer in enumerate(model.layers): |
| 112 | + key = f"model.layers.{idx}.self_attn.R2" |
| 113 | + R2 = optimized_rotation[key].to(torch.float32) |
| 114 | + rotate_attention_inputs(layer, R1) |
| 115 | + rotate_attention_output(layer, R1) |
| 116 | + rotate_mlp_input(layer, R1) |
| 117 | + rotate_mlp_output(layer, R1) |
| 118 | + rotate_ov_proj(layer, head_dim, R2=R2) |
| 119 | + return model |
| 120 | + |
| 121 | + |
| 122 | +def fuse_ln_linear( |
| 123 | + layernorm: torch.nn.Module, linear_layers: typing.Iterable[torch.nn.Linear] |
| 124 | +) -> None: |
| 125 | + """ |
| 126 | + fuse the linear operations in Layernorm into the adjacent linear blocks. |
| 127 | + """ |
| 128 | + for linear in linear_layers: |
| 129 | + linear_dtype = linear.weight.dtype |
| 130 | + |
| 131 | + # Calculating new weight and bias |
| 132 | + W_ = linear.weight.data.to(dtype=torch.float32) |
| 133 | + linear.weight.data = (W_ * layernorm.weight.to(dtype=torch.float32)).to( |
| 134 | + linear_dtype |
| 135 | + ) |
| 136 | + |
| 137 | + if hasattr(layernorm, "bias"): |
| 138 | + if linear.bias is None: |
| 139 | + linear.bias = torch.nn.Parameter( |
| 140 | + torch.zeros(linear.out_features, dtype=torch.float32) |
| 141 | + ) |
| 142 | + linear.bias.data = linear.bias.data.to(dtype=torch.float32) + torch.matmul( |
| 143 | + W_, layernorm.bias.to(dtype=torch.float32) |
| 144 | + ) |
| 145 | + linear.bias.data = linear.bias.data.to(linear_dtype) |
| 146 | + |
| 147 | + |
| 148 | +def fuse_layer_norms(model: torch.nn.Module): |
| 149 | + # Embedding fusion |
| 150 | + for W in [model.tok_embeddings]: |
| 151 | + W_ = W.weight.data.to(dtype=torch.float32) |
| 152 | + W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype) |
| 153 | + |
| 154 | + # Fuse the linear operations in Layernorm into the adjacent linear blocks. |
| 155 | + for layer in model.layers: |
| 156 | + # fuse the input layernorms into the linear layers |
| 157 | + fuse_ln_linear(layer.ffn_norm, [layer.feed_forward.w3, layer.feed_forward.w1]) |
| 158 | + fuse_ln_linear( |
| 159 | + layer.attention_norm, |
| 160 | + [ |
| 161 | + layer.attention.wq, |
| 162 | + layer.attention.wk, |
| 163 | + layer.attention.wv, |
| 164 | + ], |
| 165 | + ) |
| 166 | + |
| 167 | + W_norm = layer.ffn_norm.weight.data |
| 168 | + layer.ffn_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32) |
| 169 | + W_norm = layer.attention_norm.weight.data |
| 170 | + layer.attention_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32) |
| 171 | + |
| 172 | + fuse_ln_linear( |
| 173 | + model.norm, |
| 174 | + [model.output], |
| 175 | + ) |
| 176 | + W_norm = model.norm.weight.data |
| 177 | + model.norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32) |
| 178 | + |
| 179 | + return model |
0 commit comments