Skip to content

Commit cc5e7f9

Browse files
committed
feat(granite): Add support for loading state_dict from safetensors
Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 6116ac8 commit cc5e7f9

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

torchchat/cli/convert_hf_checkpoint.py

Lines changed: 25 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -102,9 +102,33 @@ def permute(w, n_heads):
102102

103103
merged_result = {}
104104
for file in sorted(bin_files):
105-
state_dict = torch.load(
105+
106+
# The state_dict can be loaded from either a torch zip file or
107+
# safetensors. We take our best guess from the name and try all
108+
# possibilities
109+
load_pt_mmap = lambda: torch.load(
106110
str(file), map_location="cpu", mmap=True, weights_only=True
107111
)
112+
load_pt_no_mmap = lambda: torch.load(
113+
str(file), map_location="cpu", mmap=False, weights_only=True
114+
)
115+
def load_safetensors():
116+
import safetensors.torch
117+
with open(file, "rb") as handle:
118+
return safetensors.torch.load(handle.read())
119+
if "safetensors" in str(file):
120+
loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap]
121+
else:
122+
loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors]
123+
124+
state_dict = None
125+
for loader in loaders:
126+
try:
127+
state_dict = loader()
128+
break
129+
except Exception:
130+
continue
131+
assert state_dict is not None, f"Unable to load tensors from {file}"
108132
merged_result.update(state_dict)
109133
final_result = {}
110134
for key, value in merged_result.items():

0 commit comments

Comments
 (0)