Skip to content

Commit f38d3f4

Browse files
committed
gguf fix
1 parent e4079f0 commit f38d3f4

File tree

4 files changed

+11
-3
lines changed

4 files changed

+11
-3
lines changed

.github/workflows/pull.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,7 @@ jobs:
731731
732732
git clone https://github.com/ggerganov/llama.cpp.git
733733
pushd llama.cpp
734+
git checkout 64ed2091b24b2f9747148fdf49a34ed5938762c3
734735
make
735736
popd
736737

.watchman-cookie-jackkhuu-mbp-1567-1101

Whitespace-only changes.

torchchat/cli/builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:
373373
kwargs = {}
374374
else:
375375
kwargs = builder_args.gguf_kwargs
376+
kwargs.setdefault("device", builder_args.device)
376377
model = Model.from_gguf(builder_args.gguf_path, **kwargs)
377378
return model
378379

torchchat/utils/gguf_loader.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,7 @@ def load_model_and_state_dict(
570570
load_state_dict: bool = True,
571571
load_as_quantized: bool = True,
572572
inner_k_tiles=8,
573+
device="cpu",
573574
) -> torch.nn.Module:
574575
"""
575576
Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
@@ -609,9 +610,14 @@ def load_model_and_state_dict(
609610
q, s, z = Q4_0.unpack(t)
610611
scales_and_zeros = pack_scales_and_zeros(s, z)
611612
q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
612-
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
613-
q_uint8, inner_k_tiles
614-
)
613+
if torch.device(device).type == "cpu":
614+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
615+
q_uint8, inner_k_tiles
616+
)
617+
else:
618+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
619+
q_uint8, inner_k_tiles
620+
)
615621
state_dict[f"{fqn}.weight"] = weight_int4pack
616622
state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
617623

0 commit comments

Comments
 (0)