Update int4pack related in torchchat gguf #1404
Changes from all commits: 29d4ab7, b884e29, c7ccb44, f60594f, e7b6f14
@@ -24,6 +24,8 @@
     pack_scales_and_zeros,
 )
 
+from torchao.dtypes.utils import is_device
+
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -128,6 +130,7 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsize
         groupsize,
         scales_and_zeros,
     )
+
     new_shape = origin_input_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
     return c
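For orientation, here is a minimal standalone sketch of the call pattern linear_int4 wraps (my own condensation, not PR code): flatten the activation to 2-D, run the int4 matmul op, then restore the leading dims as in the hunk above. The _weight_int4pack_mm_for_cpu branch is an assumption paired with the CPU packing op introduced below; this hunk itself only shows the reshape tail.

import torch

def linear_int4_sketch(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    # Flatten leading dims: the int4 matmul ops take a 2-D activation
    # (dtypes must match what the respective kernels support).
    origin_input_size = x.size()
    x = x.reshape(-1, origin_input_size[-1])
    if x.device.type == "cpu":
        # Assumed CPU variant paired with _convert_weight_to_int4pack_for_cpu.
        c = torch.ops.aten._weight_int4pack_mm_for_cpu(
            x, weight_int4pack, groupsize, scales_and_zeros
        )
    else:
        c = torch.ops.aten._weight_int4pack_mm(
            x, weight_int4pack, groupsize, scales_and_zeros
        )
    # Restore the leading dims with the new feature size.
    return c.reshape(origin_input_size[:-1] + (out_features,))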
@@ -178,16 +181,27 @@ def __init__(
         ), "must specify both weights and scales_and_zeros, or neither"
 
         if weight is None:
-            weight = torch.empty(
-                (
-                    out_features // 8,
-                    in_features // (inner_k_tiles * 16),
-                    32,
-                    inner_k_tiles // 2,
-                ),
-                dtype=torch.int32,
-                device=device,
-            )
+            if is_device(device, "cpu"):
+                weight = torch.empty(
+                    (
+                        out_features,
+                        in_features // 2,
+                    ),
+                    dtype=torch.uint8,
+                    device=device,
+                )
+            else:
+                weight = torch.empty(
+                    (
+                        out_features // 8,
+                        in_features // (inner_k_tiles * 16),
+                        32,
+                        inner_k_tiles // 2,
+                    ),
+                    dtype=torch.int32,
+                    device=device,
+                )
 
             scales_and_zeros = torch.empty(
                 (in_features // groupsize, out_features, 2),
                 dtype=get_precision(),
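For context, a minimal sketch of how the two placeholder layouts in this hunk differ (the helper name and example sizes are made up, not PR code): the CPU kernel consumes a flat uint8 tensor with two 4-bit values per byte, while the other devices use the tiled int32 layout.

import torch

def empty_int4pack_weight(out_features, in_features, inner_k_tiles, device):
    # Hypothetical helper mirroring the branch added in __init__ above.
    if torch.device(device).type == "cpu":
        # CPU packed layout: two 4-bit values per byte -> [N, K // 2] uint8.
        return torch.empty(
            (out_features, in_features // 2), dtype=torch.uint8, device=device
        )
    # Tiled layout used by the non-CPU kernel:
    # [N // 8, K // (inner_k_tiles * 16), 32, inner_k_tiles // 2] int32.
    return torch.empty(
        (
            out_features // 8,
            in_features // (inner_k_tiles * 16),
            32,
            inner_k_tiles // 2,
        ),
        dtype=torch.int32,
        device=device,
    )

# With out_features=4096, in_features=4096, inner_k_tiles=8:
#   cpu   -> torch.Size([4096, 2048]), uint8
#   other -> torch.Size([512, 32, 32, 4]), int32
# Both hold 4096 * 4096 four-bit values (8 MiB).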
@@ -223,12 +237,17 @@ def _prepare_weight_and_scales_and_zeros(
         weight_int32, scales_and_zeros = group_quantize_tensor(
             weight_bf16, n_bit=4, groupsize=groupsize
         )
-        weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(
-            torch.uint8
-        )
-        weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-            weight_uint8, inner_k_tiles
-        )
+        if is_device(weight_int32.device.type, "cpu"):
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                weight_int32, inner_k_tiles
+            )
+        else:
+            weight_uint8 = (weight_int32[::, ::2] << 4 | weight_int32[::, 1::2]).to(
+                torch.uint8
+            )
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                weight_uint8, inner_k_tiles
+            )
         return weight_int4pack, scales_and_zeros
 
     @classmethod
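As a standalone illustration of the bit trick kept in the non-CPU branch (example values are my own): packing puts even columns in the high nibble and odd columns in the low nibble, halving the last dimension, and both nibbles round-trip losslessly.

import torch

# Unpacked 4-bit values stored as int32, as group quantization produces them.
q = torch.randint(0, 16, (4, 8), dtype=torch.int32)

# Pack column pairs into bytes: even column -> high nibble, odd -> low nibble.
packed = (q[:, ::2] << 4 | q[:, 1::2]).to(torch.uint8)  # shape [4, 4]

# Round trip: both nibbles are recoverable from the packed byte.
assert torch.equal((packed >> 4).to(torch.int32), q[:, ::2])
assert torch.equal((packed & 0xF).to(torch.int32), q[:, 1::2])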
@@ -609,17 +628,14 @@ def load_model_and_state_dict(
                 if load_state_dict:
                     q, s, z = Q4_0.unpack(t)
                     scales_and_zeros = pack_scales_and_zeros(s, z)
-                    q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
-
-                    if torch.device(device).type == "cpu":
-                        weight_int4pack = (
-                            torch.ops.aten._convert_weight_to_int4pack_for_cpu(
-                                q, inner_k_tiles
-                            )
+                    if is_device(q.device.type, "cpu"):
+                        weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                            q, inner_k_tiles
                         )
                     else:
+                        q_tmp = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
                         weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                            q_uint8, inner_k_tiles
+                            q_tmp, inner_k_tiles
                         )
                     state_dict[f"{fqn}.weight"] = weight_int4pack
                     state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
|
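The same CPU/other-device dispatch now appears both in _prepare_weight_and_scales_and_zeros and in this GGUF load path. If one wanted to factor it out, a sketch could look like this (hypothetical helper name; assumes a PyTorch build that ships both conversion ops):

import torch

def convert_to_int4pack(q_int32: torch.Tensor, inner_k_tiles: int) -> torch.Tensor:
    # Hypothetical shared helper mirroring both call sites above.
    if q_int32.device.type == "cpu":
        # The CPU op consumes the unpacked int32 weight directly.
        return torch.ops.aten._convert_weight_to_int4pack_for_cpu(
            q_int32, inner_k_tiles
        )
    # Other devices expect the nibble-packed uint8 tensor.
    q_uint8 = (q_int32[:, ::2] << 4 | q_int32[:, 1::2]).to(torch.uint8)
    return torch.ops.aten._convert_weight_to_int4pack(q_uint8, inner_k_tiles)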
@@ -632,7 +648,7 @@ def load_model_and_state_dict(
                         in_features=in_features,
                         out_features=out_features,
                         bias=False,
-                        device="meta",
+                        device="cpu",
                         groupsize=Q4_0.groupsize,
                         inner_k_tiles=inner_k_tiles,
                     ),

Review comment on device="cpu": Let's keep this as a meta device as long as we can.

Reply: Now only CPU behaves differently from CUDA and meta (https://github.com/pytorch/torchchat/pull/1404/files/b884e295a164fa0b8cd172196e4409e51315567b#diff-28cab20c48af32e561f6e95cec7d029fa076708223a00d64afa80ad62b9b52a4R192). With device="meta" here, the module cannot tell the right shape of the weight.
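To make the reviewer exchange concrete (standalone sketch with made-up sizes, not PR code): is_device(device, "cpu") is false for "meta", so a meta-device module would be built with the tiled int32 placeholder even though the real CPU weights end up in the flat uint8 layout, and the two would disagree at load time.

import torch

out_features, in_features, inner_k_tiles = 4096, 4096, 8

# Placeholder a meta-device module would allocate (the non-CPU branch):
w_meta = torch.empty(
    (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2),
    dtype=torch.int32,
    device="meta",
)
print(w_meta.shape)  # torch.Size([512, 32, 32, 4])

# Shape of the actual CPU-packed weight produced at load time:
w_cpu = torch.empty((out_features, in_features // 2), dtype=torch.uint8)
print(w_cpu.shape)  # torch.Size([4096, 2048]) -- shape and dtype both differ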
Review comment: nice