Commit aae4eb3

Attempt to account for int4wo packing pt#139611
1 parent d67eb86 commit aae4eb3

File tree

1 file changed: +37 -4 lines changed

1 file changed

+37
-4
lines changed

torchchat/utils/gguf_loader.py

Lines changed: 37 additions & 4 deletions
@@ -122,7 +122,21 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsize):
             input.dtype
         )  # cast back to input.dtype
     else:
-        c = torch.ops.aten._weight_int4pack_mm(
+        # copied from https://github.com/pytorch/pytorch/blob/0052943bee624c06d8c36a371efdf7b56972ad9e/torch/_meta_registrations.py#L3308
+        def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
+            torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
+            torch._check(w.dim() == 4, lambda: "w must be a 4D tensor")
+            torch._check(
+                x.dtype in [torch.float32, torch.float16, torch.bfloat16],
+                lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
+            )
+            torch._check(
+                w.dtype is torch.int32,
+                lambda: f"expected w to be int32, got {w.dtype}",
+            )
+            return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)
+
+        c = meta__weight_int4pack_mm(
             input,
             weight_int4pack,
             groupsize,
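
For orientation, a minimal sketch (not part of the commit) of the shape arithmetic behind the shim's return value x.new_empty(x.size(0), w.size(0) * 8, ...): w is the 4D int4-packed weight whose layout the second hunk asserts, and each entry along its first dimension covers 8 output rows. The sizes n, k, m, and inner_k_tiles here are hypothetical, chosen only so the assertions hold.

import torch

# Hypothetical sizes: n = out_features, k = in_features, m = batch rows.
n, k, inner_k_tiles, m = 64, 256, 8, 4

# Packed int4 weight layout asserted by the meta registration in the
# second hunk: [n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2].
w_packed = torch.empty(
    (n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2),
    dtype=torch.int32,
)
x = torch.randn(m, k, dtype=torch.bfloat16)

# The shim recovers out_features from the packed first dim: 8 rows per entry.
out = x.new_empty(x.size(0), w_packed.size(0) * 8, dtype=x.dtype)
assert out.shape == (m, n)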
@@ -610,10 +624,29 @@ def load_model_and_state_dict(
             q, s, z = Q4_0.unpack(t)
             scales_and_zeros = pack_scales_and_zeros(s, z)
             q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
-
+
             if torch.device(device).type == "cpu":
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
-                    q_uint8, inner_k_tiles
+                # Copied from https://github.com/pytorch/pytorch/blob/0052943bee624c06d8c36a371efdf7b56972ad9e/torch/_meta_registrations.py#L3273
+                def meta__convert_weight_to_int4pack(w, inner_k_tiles):
+                    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
+                    torch._check(
+                        w.dtype is torch.uint8,
+                        lambda: f"expected w to be uint8, got {w.dtype}",
+                    )
+                    n = w.size(0)
+                    k = w.size(1) * 2  # w is [n][k / 2] uint8
+                    return w.new_empty(
+                        (
+                            n // 8,
+                            k // (inner_k_tiles * 16),
+                            32,
+                            inner_k_tiles // 2,
+                        ),
+                        dtype=torch.int32,
+                    )
+
+                weight_int4pack = meta__convert_weight_to_int4pack(
+                    q_uint8, inner_k_tiles
                 )
             else:
                 weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
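
As a sanity check on that packed layout, a sketch (again with hypothetical sizes, not from the commit) showing that the int32 tensor returned by the shim accounts for exactly n * k int4 nibbles, and that the q_uint8 pre-packing step above pairs two int4 values per byte:

import math
import torch

n, k, inner_k_tiles = 64, 256, 8  # hypothetical sizes

# The [n//8, k//(inner_k_tiles*16), 32, inner_k_tiles//2] int32 tensor packs
# 8 int4 nibbles per element, which recovers the full n x k weight.
packed_shape = (n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2)
assert math.prod(packed_shape) * 8 == n * k

# The q_uint8 step from the diff: two int4 values per byte, [n, k] -> [n, k // 2].
q = torch.randint(0, 16, (n, k), dtype=torch.int32)
q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
assert q_uint8.shape == (n, k // 2)

# Round trip: high and low nibbles recover the original interleaved values.
assert torch.equal(q_uint8 >> 4, q[::, ::2].to(torch.uint8))
assert torch.equal(q_uint8 & 0xF, q[::, 1::2].to(torch.uint8))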
