Skip to content

Commit 8580712

Browse files
mikekgfbmalfet
authored and committed
add unpacking support (#525)
* add unpacking support * fix typos and linter
1 parent 6732127 commit 8580712

File tree

1 file changed

+53
-1
lines changed

1 file changed

+53
-1
lines changed

build/utils.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,61 @@
99
import logging
1010
import os
1111
from pathlib import Path
12-
from typing import Dict, List
12+
13+
##########################################################################
14+
### unpack packed weights ###
15+
16+
from typing import Any, Callable, Dict, List, Optional, Tuple
1317

1418
import torch
19+
import torch.nn.functional as F
20+
21+
22+
def unpack_packed_weights(
    packed_weights: Dict[str, Any],
    packed_linear: Callable,
    input_dtype: torch.dtype,
    unpacked_dims: Tuple,
) -> torch.Tensor:
    """Recreate an unpacked weight matrix from its packed representation.

    The packed data is never decoded directly. Instead, an identity matrix
    of size ``unpacked_dims[1]`` is passed through ``packed_linear`` together
    with the packed weights, and the result is transposed. Assuming
    ``packed_linear`` computes ``x @ W.T`` (as ``torch.nn.functional.linear``
    does), that transposed result is the unpacked weight matrix ``W``.

    The input data type is required because packing may depend on the input
    dtype, or only some input dtypes may be supported by the packed routine.
    At present this does not handle padding. The same approach can be used
    for both linear and mm operators.

    Args:
        packed_weights: Keyword arguments forwarded verbatim to
            ``packed_linear`` (the packed weight plus whatever else the
            packed routine expects, e.g. group size, scales and zeros).
        packed_linear: Callable invoked as ``packed_linear(x, **packed_weights)``.
        input_dtype: dtype of the identity matrix fed to ``packed_linear``.
        unpacked_dims: ``(rows, cols)`` of the unpacked matrix. Only ``cols``
            is used (to size the identity matrix); ``rows`` is not checked
            against the result.

    Returns:
        The output of ``packed_linear`` on the identity matrix, with
        dimensions 0 and 1 swapped.

    Raises:
        AssertionError: If ``unpacked_dims`` does not have exactly two entries.

    Example usage:
        packed_weights = {
            "weight": weight_int4pack,
            "qGroupSize": groupsize,
            "scales_and_zeros": scales_and_zeros,
        }
        unpacked_weights = unpack_packed_weights(
            packed_weights,
            _weight_int4pack_linear,
            torch.bfloat16,
            (256, 1024),
        )
    """
    # Raised explicitly (rather than via `assert`) so the check is not
    # stripped under `python -O`; the exception type and message are kept
    # identical for backward compatibility.
    if len(unpacked_dims) != 2:
        raise AssertionError("unpacked_dims must be a tuple of length 2")
    cols = unpacked_dims[1]

    # Build the identity on the same device as the packed tensors so the
    # packed routine is not handed operands on mismatched devices. When no
    # tensor is present, `device=None` selects the default device, which is
    # what a plain `torch.eye(cols, dtype=...)` would use.
    device = next(
        (
            value.device
            for value in packed_weights.values()
            if isinstance(value, torch.Tensor)
        ),
        None,
    )
    identity = torch.eye(cols, dtype=input_dtype, device=device)

    return packed_linear(identity, **packed_weights).transpose(0, 1)
1567

1668

1769
##########################################################################

0 commit comments

Comments
 (0)