Commit c5d5c1f

del logits=(bs, seq_len, vocab_size) to save 3.9G memory (#391)
The logits (named `pred` in train.py) have shape (bs, seq_len, vocab_size). Call `del` on them after computing the loss so they can be freed before backward. Screenshot: https://github.com/pytorch/torchtitan/assets/134637289/82db2792-59a3-40c4-9591-842be3dd9284
1 parent e858ab4 commit c5d5c1f
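
For a sense of scale, the size of a logits tensor follows directly from its shape and element width. The sketch below is only an illustration: the commit does not state the batch size, sequence length, vocabulary size, or dtype, so the values used here (bs=1, seq_len=8192, vocab_size=128256, 4-byte float32 elements) are assumptions chosen because they land close to the 3.9G figure in the title.

```python
def logits_bytes(bs, seq_len, vocab_size, bytes_per_element):
    """Bytes needed for a dense (bs, seq_len, vocab_size) tensor."""
    return bs * seq_len * vocab_size * bytes_per_element


# Assumed values, NOT taken from the commit: one sequence of 8192 tokens,
# a 128256-entry vocabulary, and float32 (4-byte) logits.
size = logits_bytes(bs=1, seq_len=8192, vocab_size=128256, bytes_per_element=4)
gib = size / 2**30

print(f"{size} bytes = {gib:.2f} GiB")
```

Under those assumptions a single logits tensor is just over 3.9 GiB, so keeping it alive through the backward pass raises peak memory by about that much.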

1 file changed: +3 -0 lines changed

train.py

Lines changed: 3 additions & 0 deletions
@@ -351,6 +351,9 @@ def loss_fn(pred, labels):
     with loss_parallel_ctx():
         pred = model(input_ids)
         loss = loss_fn(pred, labels)
+        # pred.shape=(bs, seq_len, vocab_size)
+        # need to free it before bwd to avoid peaking memory
+        del pred
         loss.backward()
 
 # clip gradients
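
The change relies on ordinary Python reference semantics: `del pred` removes the local name, and the object can be reclaimed once nothing else refers to it. The sketch below shows that mechanism in plain Python, without PyTorch. `FakeTensor` and `forward_and_loss` are hypothetical stand-ins, not torchtitan code, and the immediate reclamation it demonstrates is CPython reference-counting behavior. In a real training step the memory is released only if nothing else, such as tensors saved by the autograd graph, still holds the logits.

```python
import weakref


class FakeTensor:
    """Hypothetical stand-in for a large activation such as the logits."""

    def __init__(self, nbytes):
        self.nbytes = nbytes


def forward_and_loss():
    """Mimic the step's order: build `pred`, derive `loss`, then drop `pred`."""
    pred = FakeTensor(nbytes=4_202_692_608)
    pred_ref = weakref.ref(pred)  # observes `pred` without keeping it alive
    loss = 0.5                    # stand-in scalar that holds no reference to `pred`
    alive_before = pred_ref() is not None
    del pred                      # drop the only strong reference before "backward"
    alive_after = pred_ref() is not None
    return loss, alive_before, alive_after


loss, alive_before, alive_after = forward_and_loss()
print(alive_before, alive_after)
```

Without the `del`, the local name `pred` would keep the object alive until it is rebound or goes out of scope, which in the training loop is after `loss.backward()` has already run.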