Commit 25da485

Naive gguf int4wo attempt
1 parent aae4eb3 commit 25da485

File tree

1 file changed, +4 -35 lines changed


torchchat/utils/gguf_loader.py

Lines changed: 4 additions & 35 deletions
@@ -122,21 +122,7 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsize):
             input.dtype
         )  # cast back to input.dtype
     else:
-        # copied from https://github.com/pytorch/pytorch/blob/0052943bee624c06d8c36a371efdf7b56972ad9e/torch/_meta_registrations.py#L3308
-        def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
-            torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
-            torch._check(w.dim() == 4, lambda: "w must be a 4D tensor")
-            torch._check(
-                x.dtype in [torch.float32, torch.float16, torch.bfloat16],
-                lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
-            )
-            torch._check(
-                w.dtype is torch.int32,
-                lambda: f"expected w to be int32, got {w.dtype}",
-            )
-            return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)
-
-        c = meta__weight_int4pack_mm(
+        c = torch.ops.aten._weight_int4pack_mm_for_cpu(
             input,
             weight_int4pack,
             groupsize,
@@ -626,27 +612,10 @@ def load_model_and_state_dict(
         q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)

         if torch.device(device).type == "cpu":
-            # Copied from https://github.com/pytorch/pytorch/blob/0052943bee624c06d8c36a371efdf7b56972ad9e/torch/_meta_registrations.py#L3273
-            def meta__convert_weight_to_int4pack(w, inner_k_tiles):
-                torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
-                torch._check(
-                    w.dtype is torch.uint8,
-                    lambda: f"expected w to be uint8, got {w.dtype}",
+            weight_int4pack = (
+                torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                    q, inner_k_tiles
                 )
-                n = w.size(0)
-                k = w.size(1) * 2  # w is [n][k / 2] uint8
-                return w.new_empty(
-                    (
-                        n // 8,
-                        k // (inner_k_tiles * 16),
-                        32,
-                        inner_k_tiles // 2,
-                    ),
-                    dtype=torch.int32,
-                )
-
-            weight_int4pack = meta__convert_weight_to_int4pack(
-                q_uint8, inner_k_tiles
             )
         else:
             weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
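
For context: the two _for_cpu ops this commit switches to come with real CPU kernels in recent PyTorch builds, which is why the copied meta-registration shims (shape-only stubs that never computed anything) can be deleted. Below is a minimal, self-contained sketch of how the pair fits together. This is not torchchat's code: the naive group-wise affine quantization, all sizes, and the zeros = min + 8 * scale convention (matching the tinygemm-style dequant (q - 8) * scale + zero, to my understanding) are illustrative assumptions.

import torch

# Illustrative sizes, assumed compatible with the CPU packing layout.
M, K, N, groupsize, inner_k_tiles = 4, 256, 64, 32, 8

x = torch.randn(M, K, dtype=torch.bfloat16)
w = torch.randn(N, K, dtype=torch.bfloat16)

# Naive group-wise affine quantization to 4-bit codes in [0, 15].
wg = w.float().reshape(N, K // groupsize, groupsize)
w_min = wg.amin(dim=-1, keepdim=True)
w_max = wg.amax(dim=-1, keepdim=True)
scales = (w_max - w_min) / 15.0
q = ((wg - w_min) / scales).round().clamp(0, 15).to(torch.int32).reshape(N, K)

# Scales/zeros in the [K // groupsize, N, 2] layout the matmul op expects;
# zeros = min + 8 * scale so that (q - 8) * scale + zero recovers ~w.
zeros = w_min + 8.0 * scales
scales_and_zeros = (
    torch.cat([scales, zeros], dim=-1).permute(1, 0, 2).to(torch.bfloat16).contiguous()
)

# Note the CPU packer consumes the *unpacked* int32 [N, K] codes, unlike
# the CUDA op, which takes uint8-packed [N, K // 2] -- hence the commit
# passes q rather than q_uint8 on the CPU path.
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
    q, inner_k_tiles
)

out = torch.ops.aten._weight_int4pack_mm_for_cpu(
    x, weight_int4pack, groupsize, scales_and_zeros
)
print(out.shape)  # torch.Size([4, 64]); out approximates x @ w.t()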
