|
9 | 9 | import logging
|
10 | 10 | import os
|
11 | 11 | from pathlib import Path
|
12 |
| -from typing import Dict, List |
| 12 | + |
| 13 | +########################################################################## |
| 14 | +### unpack packed weights ### |
| 15 | + |
| 16 | +from typing import Any, Callable, Dict, List, Optional, Tuple |
13 | 17 |
|
14 | 18 | import torch
|
| 19 | +import torch.nn.functional as F |
| 20 | + |
| 21 | + |
def unpack_packed_weights(
    packed_weights: Dict[str, Any],
    packed_linear: Callable,
    input_dtype: torch.dtype,
    unpacked_dims: Tuple,
) -> torch.Tensor:
    """Recreate an unpacked weight matrix from its packed representation.

    Given a packed weight matrix `packed_weights`, a Callable implementing
    a packed linear function for the packed format, and the unpacked
    dimensions of the weights, recreate the unpacked weight matrix.  The
    trick: for a linear op ``y = x @ W.T``, feeding the identity matrix as
    input yields ``W.T``, so transposing the result recovers ``W``.

    In addition to the packed weights (as a dictionary supplying whatever
    keyword arguments the packed routine expects), we also need the input
    data type, because packing may depend on the input dtype, or only some
    input dtypes may be supported.  At present this does not handle
    padding, but that will be straightforward to add.  Similarly, the same
    approach can be used for both linear and mm operators.

    Args:
        packed_weights: keyword arguments forwarded verbatim to
            ``packed_linear`` (e.g. packed tensor, group size,
            scales/zeros).
        packed_linear: callable implementing the packed linear op; invoked
            as ``packed_linear(input, **packed_weights)``.
        input_dtype: dtype of the identity input fed through the op.
        unpacked_dims: ``(rows, cols)`` of the unpacked weight matrix.

    Returns:
        A tensor of shape ``unpacked_dims`` holding the unpacked weights.

    Raises:
        ValueError: if ``unpacked_dims`` does not have exactly 2 entries.

    Example usage:
        packed_weights = {
            "weight": weight_int4pack,
            "qGroupSize": groupsize,
            "scales_and_zeros": scales_and_zeros,
        }
        unpacked_weights = unpack_packed_weights(
            packed_weights,
            _weight_int4pack_linear,
            torch.bfloat16,
            (256, 1024),
        )
    """
    # Explicit check rather than `assert`: asserts are stripped under -O.
    if len(unpacked_dims) != 2:
        raise ValueError("unpacked_dims must be a tuple of length 2")
    cols = unpacked_dims[1]

    # packed_linear(I) computes I @ W.T == W.T; transpose to get W back
    # in (rows, cols) layout.
    unpacked_weights = packed_linear(
        torch.eye(cols, dtype=input_dtype), **packed_weights
    ).transpose(0, 1)
    return unpacked_weights
15 | 67 |
|
16 | 68 |
|
17 | 69 | ##########################################################################
|
|
0 commit comments