@@ -102,9 +102,33 @@ def permute(w, n_heads):
102
102
103
103
merged_result = {}
104
104
for file in sorted (bin_files ):
105
- state_dict = torch .load (
105
+
106
+ # The state_dict can be loaded from either a torch zip file or
107
+ # safetensors. We take our best guess from the name and try all
108
+ # possibilities
109
+ load_pt_mmap = lambda : torch .load (
106
110
str (file ), map_location = "cpu" , mmap = True , weights_only = True
107
111
)
112
+ load_pt_no_mmap = lambda : torch .load (
113
+ str (file ), map_location = "cpu" , mmap = False , weights_only = True
114
+ )
115
+ def load_safetensors ():
116
+ import safetensors .torch
117
+ with open (file , "rb" ) as handle :
118
+ return safetensors .torch .load (handle .read ())
119
+ if "safetensors" in str (file ):
120
+ loaders = [load_safetensors , load_pt_mmap , load_pt_no_mmap ]
121
+ else :
122
+ loaders = [load_pt_mmap , load_pt_no_mmap , load_safetensors ]
123
+
124
+ state_dict = None
125
+ for loader in loaders :
126
+ try :
127
+ state_dict = loader ()
128
+ break
129
+ except Exception :
130
+ continue
131
+ assert state_dict is not None , f"Unable to load tensors from { file } "
108
132
merged_result .update (state_dict )
109
133
final_result = {}
110
134
for key , value in merged_result .items ():
0 commit comments