@@ -70,7 +70,7 @@ def __init__(self, **kwargs):
70
70
# Follow the instruction in https://github.com/facebookresearch/llama to download the model
71
71
device = "cpu"
72
72
# flake8: noqa: TOR102
73
- checkpoint = torch .load (checkpoint_path , map_location = device )
73
+ checkpoint = torch .load (checkpoint_path , map_location = device , mmap = True )
74
74
fairseq2_checkpoint = kwargs .get ("fairseq2" , False )
75
75
if fairseq2_checkpoint :
76
76
print ("Using fairseq2 checkpoint" )
@@ -130,7 +130,11 @@ def __init__(self, **kwargs):
130
130
for key , weights in checkpoint .items ():
131
131
print (f"{ key } : { weights .numel ()} : { weights .size ()} " )
132
132
print ("============= /weights ================" )
133
- self .model_ = Transformer (model_args )
133
+
134
+ # Within the device="meta" context, tensors that are created do not carry data.
135
+ # They possess all other metadata a tensor carries such as size, stride, requires_grad.
136
+ with torch .device ("meta" ):
137
+ self .model_ = Transformer (model_args )
134
138
135
139
if "int8" in str (checkpoint_path ):
136
140
print ("Using int8 weight-only quantization!" )
@@ -142,11 +146,16 @@ def __init__(self, **kwargs):
142
146
print ("Using int4 weight-only quantization!" )
143
147
from .quantize import Int8DynActInt4WeightQuantHandler
144
148
145
- simple_quantizer = INt8dynactint4weightquanthandler (self .model_ )
149
+ simple_quantizer = Int8DynActInt4WeightQuantHandler (self .model_ )
146
150
self .model_ = simple_quantizer .convert_for_runtime ()
147
151
152
+ # assign=True: load params/buffers by assignment instead of performing an in-place copy.
153
+ # Because we are using device="meta", tensors do not have memory associated with them
154
+ # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
148
155
self .model_ .load_state_dict (
149
- checkpoint , strict = False
156
+ checkpoint ,
157
+ strict = False ,
158
+ assign = True ,
150
159
) # self.model_ = Transformer(gptconf)
151
160
152
161
def get_eager_model (self ):
0 commit comments