
Commit d5f898d

lucylq authored and facebook-github-bot committed
Create model with device='meta'
Summary:
See discussion: D54825007

Two optimizations:
1. Use `mmap=True` to load the checkpoint.
2. Create the model with device="meta". Tensors created in this context do not carry data; they are only assigned real storage when we load the state dict. Previously, the llama7b model was created in fp32 (the default), using up 25GB of RAM.
   - Note: non-persistent buffers and tensors that do not have keys in the state dict will be created with device="meta" as well. These have to be manually initialized when creating the model. See D46784302.

Checkpoint loading time: 10s -> 0.011s
Peak memory usage: [37.8GB](https://lookaside.facebook.com/intern/diff/file/data/?number=1467921211&download=1) -> [25.5GB](https://lookaside.facebook.com/intern/diff/file/data/?number=1468357208&download=1)
Model creation time: [77s](https://lookaside.facebook.com/intern/diff/file/data/?number=1468360493&download=1) -> [11.6s](https://lookaside.facebook.com/intern/diff/file/data/?number=1468364061&download=1)

Follow-up (T182328293): iterate over params/buffers and initialize any uninitialized tensors, instead of initializing them manually, which is model-specific.

Thanks iseeyuan for the tips: https://pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html

bypass-github-export-checks

Reviewed By: iseeyuan

Differential Revision: D54871495

fbshipit-source-id: f6a8d01c88ce45bb5d2358cc522ba81c0e1fbd5b
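
For context, a minimal sketch of the loading pattern this change adopts (following the PyTorch recipe linked above). `TinyModel` and the checkpoint path are placeholders, not part of this diff:

```python
import torch
from torch import nn


# Hypothetical stand-in for the llama2 Transformer touched by this diff.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)


torch.save(TinyModel().state_dict(), "/tmp/checkpoint.pth")  # placeholder checkpoint

# 1. mmap=True maps the checkpoint file instead of reading every tensor into RAM up front.
checkpoint = torch.load("/tmp/checkpoint.pth", map_location="cpu", mmap=True)

# 2. Under device="meta", parameters and buffers are created without storage,
#    so building the model does not allocate the full fp32 weights.
with torch.device("meta"):
    model = TinyModel()

# assign=True swaps the meta tensors for the checkpoint tensors outright; an
# in-place copy into a tensor that has no storage would be a no-op.
model.load_state_dict(checkpoint, strict=False, assign=True)
```

Any tensor that has no matching key in the checkpoint stays on the meta device after this, which is why the buffers in llama_transformer.py below are pinned to device="cpu".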
1 parent d0512b6 commit d5f898d

File tree

2 files changed: +17 -5 lines

examples/models/llama2/llama_transformer.py

Lines changed: 4 additions & 1 deletion
@@ -110,7 +110,9 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:


 def precompute_freqs_cis(dim: int, end: int, theta: float):
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    freqs = 1.0 / (
+        theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
+    )
     t = torch.arange(end, device=freqs.device)  # pyre-ignore
     freqs = torch.outer(t, freqs).float()  # pyre-ignore
     freqs_cos = torch.cos(freqs)

@@ -171,6 +173,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
         mask = torch.full(
             (1, 1, args.max_seq_len, args.max_seq_len),
             float("-inf"),
+            device="cpu",
         )

         mask = torch.triu(mask, diagonal=1)
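
Why pin device="cpu" here: inside the `with torch.device("meta")` block added in model.py, factory calls such as `torch.arange` and `torch.full` that omit an explicit device would also land on the meta device, leaving `freqs` and the attention mask without data. A standalone illustration of that behavior (not part of this diff):

```python
import torch

with torch.device("meta"):
    implicit = torch.arange(0, 8, 2)                # inherits the meta device: no data
    explicit = torch.arange(0, 8, 2, device="cpu")  # real CPU tensor, carries data

print(implicit.is_meta)  # True
print(explicit.is_meta)  # False
```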

examples/models/llama2/model.py

Lines changed: 13 additions & 4 deletions
@@ -70,7 +70,7 @@ def __init__(self, **kwargs):
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model
         device = "cpu"
         # flake8: noqa: TOR102
-        checkpoint = torch.load(checkpoint_path, map_location=device)
+        checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
         fairseq2_checkpoint = kwargs.get("fairseq2", False)
         if fairseq2_checkpoint:
             print("Using fairseq2 checkpoint")

@@ -130,7 +130,11 @@ def __init__(self, **kwargs):
             for key, weights in checkpoint.items():
                 print(f"{key} : {weights.numel()} : {weights.size()}")
             print("============= /weights ================")
-        self.model_ = Transformer(model_args)
+
+        # Within the device="meta" context, tensors that are created do not carry data.
+        # They possess all other metadata a tensor carries such as size, stride, requires_grad.
+        with torch.device("meta"):
+            self.model_ = Transformer(model_args)

         if "int8" in str(checkpoint_path):
             print("Using int8 weight-only quantization!")

@@ -142,11 +146,16 @@ def __init__(self, **kwargs):
             print("Using int4 weight-only quantization!")
             from .quantize import Int8DynActInt4WeightQuantHandler

-            simple_quantizer = INt8dynactint4weightquanthandler(self.model_)
+            simple_quantizer = Int8DynActInt4WeightQuantHandler(self.model_)
             self.model_ = simple_quantizer.convert_for_runtime()

+        # assign=True: load params/buffers by assignment instead of performing an in-place copy.
+        # Because we are using device="meta", tensors do not have memory associated with them
+        # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
         self.model_.load_state_dict(
-            checkpoint, strict=False
+            checkpoint,
+            strict=False,
+            assign=True,
         )  # self.model_ = Transformer(gptconf)

     def get_eager_model(self):
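
The follow-up noted in the summary (iterate over params/buffers to find tensors the state dict did not cover, instead of initializing them by hand per model) could look roughly like the sketch below; `find_meta_tensors` is a hypothetical helper, not part of this commit:

```python
import torch
from torch import nn


def find_meta_tensors(module: nn.Module) -> list[str]:
    """Names of params/buffers still on the meta device after load_state_dict,
    e.g. non-persistent buffers or entries missing from the checkpoint."""
    meta_names = [name for name, p in module.named_parameters() if p.is_meta]
    meta_names += [name for name, b in module.named_buffers() if b.is_meta]
    return meta_names


# Example: nothing loaded yet, so every tensor is still uninitialized.
with torch.device("meta"):
    probe = nn.Linear(4, 4)
print(find_meta_tensors(probe))  # ['weight', 'bias']
```

Tensors reported this way would then be re-created on a real device (e.g. by re-running their original init), which is what the manual device="cpu" pins above already do for the rope frequencies and the attention mask.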
