2.4.0 Updates to DCP #2968

Merged

17 commits:
f4ecb4c  adds async save and stateful info (LucasLLC)
31f1066  adds format utils (LucasLLC)
a941fb0  type updates (LucasLLC)
3c86d8f  Merge branch 'main' into dcp_async_save (LucasLLC)
8d6e7dc  removes loss fn unused (LucasLLC)
16a2f05  adds async to index (LucasLLC)
ea1a89e  Merge branch 'main' into dcp_async_save (svekars)
0e3c3ec  formatting updates (LucasLLC)
54a6a30  formatting (LucasLLC)
e474630  Merge branch 'dcp_async_save' of github.com:pytorch/tutorials into dc… (LucasLLC)
25ea481  Update recipes_source/distributed_async_checkpoint_recipe.rst (LucasLLC)
e6b3ac2  Update recipes_source/distributed_async_checkpoint_recipe.rst (LucasLLC)
f4ec793  Update recipes_source/distributed_async_checkpoint_recipe.rst (LucasLLC)
0a483b1  Merge branch 'main' into dcp_async_save (LucasLLC)
51a9b61  spelling (LucasLLC)
6acfa55  Merge branch 'dcp_async_save' of github.com:pytorch/tutorials into dc… (LucasLLC)
ed9c46a  Merge branch 'main' into dcp_async_save (svekars)

File: recipes_source/distributed_async_checkpoint_recipe.rst (new file, +291 lines)

Asynchronous Saving with Distributed Checkpoint (DCP)
=====================================================

**Author:** `Lucas Pasqualin <https://github.com/lucasllc>`__, `Iris Zhang <https://github.com/wz337>`__, `Rodrigo Kumpera <https://github.com/kumpera>`__, `Chien-Chin Huang <https://github.com/fegin>`__

Checkpointing is often a bottleneck in the critical path for distributed training workloads, incurring larger and larger costs as both model and world sizes grow.
One excellent strategy for offsetting this cost is to checkpoint in parallel, asynchronously. Below, we expand the save example
from the `Getting Started with Distributed Checkpoint Tutorial <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst>`__
to show how this can be integrated quite easily with ``torch.distributed.checkpoint.async_save``.
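
At a high level, the API looks like the following minimal sketch, where ``state_dict`` is a placeholder for the application state built up in the full example below:

.. code-block:: python

    import torch.distributed.checkpoint as dcp

    # async_save returns a future immediately and writes the checkpoint in the
    # background, so training can proceed while the write is in flight
    future = dcp.async_save(state_dict, checkpoint_id="checkpoint_step0")
    ...  # continue training
    future.result()  # block only when the checkpoint must be complete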

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to use DCP to generate checkpoints in parallel
       * Effective strategies to optimize performance

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch v2.4.0 or later
       * `Getting Started with Distributed Checkpoint Tutorial <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst>`__

Asynchronous Checkpointing Overview
------------------------------------
Before getting started with Asynchronous Checkpointing, it's important to understand its differences and limitations as compared to synchronous checkpointing.
Specifically:

* Memory requirements - Asynchronous checkpointing works by first copying models into internal CPU buffers.
  This is helpful since it ensures model and optimizer weights are not changing while the model is still checkpointing,
  but does raise CPU memory usage by a factor of ``checkpoint_size_per_rank X number_of_ranks``. Additionally, users should take care to understand
  the memory constraints of their systems. Specifically, pinned memory implies the usage of ``page-locked`` memory, which can be scarce as compared to
  ``pageable`` memory.

* Checkpoint Management - Since checkpointing is asynchronous, it is up to the user to manage concurrently running checkpoints. In general, users can
  employ their own management strategies by handling the future object returned from ``async_save``. For most users, we recommend limiting
  checkpoints to one asynchronous request at a time, avoiding additional memory pressure per request; one alternative pattern is sketched below.
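
As one example of such a management strategy, the sketch below drops a checkpoint request when the previous one has not yet finished, rather than blocking on it. This is a hedged sketch, not part of the DCP API: ``num_steps`` and ``state_dict`` are placeholders, and it assumes the future returned by ``async_save`` behaves like ``concurrent.futures.Future`` (that is, it exposes ``done()`` and ``result()``):

.. code-block:: python

    import torch.distributed.checkpoint as dcp

    checkpoint_future = None
    for step in range(num_steps):
        ...  # forward, backward, optimizer step (as in the full example below)

        # skip this checkpoint request entirely if the previous one is still
        # in flight, instead of stalling training on it
        if checkpoint_future is None or checkpoint_future.done():
            checkpoint_future = dcp.async_save(state_dict, checkpoint_id=f"checkpoint_step{step}")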

The full example below integrates ``async_save`` into an FSDP training loop, waiting on the previous request before issuing the next one:

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp
    import torch.multiprocessing as mp
    import torch.nn as nn

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
    from torch.distributed.checkpoint.stateful import Stateful

    CHECKPOINT_DIR = "checkpoint"


    class AppState(Stateful):
        """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
        with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
        dcp.save/load APIs.

        Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
        and optimizer.
        """

        def __init__(self, model, optimizer=None):
            self.model = model
            self.optimizer = optimizer

        def state_dict(self):
            # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
            return {
                "model": model_state_dict,
                "optim": optimizer_state_dict
            }

        def load_state_dict(self, state_dict):
            # sets our state dicts on the model and optimizer, now that we've loaded
            set_state_dict(
                self.model,
                self.optimizer,
                model_state_dict=state_dict["model"],
                optim_state_dict=state_dict["optim"]
            )


    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(16, 16)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(16, 8)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    def setup(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"

        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)


    def cleanup():
        dist.destroy_process_group()


    def run_fsdp_checkpoint_save_example(rank, world_size):
        print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
        setup(rank, world_size)

        # create a model and move it to GPU with id rank
        model = ToyModel().to(rank)
        model = FSDP(model)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

        checkpoint_future = None
        for step in range(10):
            optimizer.zero_grad()
            model(torch.rand(8, 16, device="cuda")).sum().backward()
            optimizer.step()

            # waits for checkpointing to finish if one exists, avoiding queuing more than one checkpoint request at a time
            if checkpoint_future is not None:
                checkpoint_future.result()

            state_dict = { "app": AppState(model, optimizer) }
            checkpoint_future = dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

        cleanup()


    if __name__ == "__main__":
        world_size = torch.cuda.device_count()
        print(f"Running async checkpoint example on {world_size} devices.")
        mp.spawn(
            run_fsdp_checkpoint_save_example,
            args=(world_size,),
            nprocs=world_size,
            join=True,
        )
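
To restore one of these checkpoints, the same ``AppState`` wrapper can be reused with ``dcp.load``. This is a minimal sketch, assuming ``model``, ``optimizer``, and the process group are set up as in the save-side example above, and that the chosen step was actually saved:

.. code-block:: python

    # dcp.load calls AppState.load_state_dict for us, because AppState
    # is compliant with the Stateful protocol
    state_dict = { "app": AppState(model, optimizer) }
    dcp.load(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step5")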


Even more performance with Pinned Memory
-----------------------------------------
If the above optimization is still not performant enough, you can take advantage of an additional optimization for GPU models which utilizes a pinned memory buffer for checkpoint staging.
Specifically, this optimization attacks the main overhead of asynchronous checkpointing, which is the in-memory copying to checkpointing buffers. By maintaining a pinned memory buffer between
checkpoint requests, users can take advantage of direct memory access to speed up this copy.
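
To see why this helps, note that a device-to-host copy into page-locked (pinned) memory can be performed via DMA and issued asynchronously with respect to the GPU. The following is an illustrative sketch of that mechanism using plain PyTorch tensors; it is not part of the DCP API:

.. code-block:: python

    import torch

    gpu_tensor = torch.rand(1024, 1024, device="cuda")

    # a pinned (page-locked) CPU buffer, allocated once and reused across copies
    pinned_buffer = torch.empty_like(gpu_tensor, device="cpu", pin_memory=True)

    # with a pinned destination, this device-to-host copy can use DMA and
    # overlap with other GPU work
    pinned_buffer.copy_(gpu_tensor, non_blocking=True)
    torch.cuda.synchronize()  # wait for the copy before reading the buffer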

.. note::
    The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without
    the pinned memory optimization (as demonstrated above), any checkpointing buffers are released as soon as
    checkpointing is finished. With the pinned memory implementation, this buffer is maintained between steps,
    leading to the same peak memory pressure being sustained through the application life.


.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp
    import torch.multiprocessing as mp
    import torch.nn as nn

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
    from torch.distributed.checkpoint.stateful import Stateful
    from torch.distributed.checkpoint import FileSystemWriter

    CHECKPOINT_DIR = "checkpoint"


    class AppState(Stateful):
        """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
        with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
        dcp.save/load APIs.

        Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
        and optimizer.
        """

        def __init__(self, model, optimizer=None):
            self.model = model
            self.optimizer = optimizer

        def state_dict(self):
            # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
            model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
            return {
                "model": model_state_dict,
                "optim": optimizer_state_dict
            }

        def load_state_dict(self, state_dict):
            # sets our state dicts on the model and optimizer, now that we've loaded
            set_state_dict(
                self.model,
                self.optimizer,
                model_state_dict=state_dict["model"],
                optim_state_dict=state_dict["optim"]
            )


    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(16, 16)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(16, 8)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    def setup(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"

        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)


    def cleanup():
        dist.destroy_process_group()


    def run_fsdp_checkpoint_save_example(rank, world_size):
        print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
        setup(rank, world_size)

        # create a model and move it to GPU with id rank
        model = ToyModel().to(rank)
        model = FSDP(model)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

        # The storage writer defines our 'staging' strategy, where staging is considered the process of copying
        # checkpoints to in-memory buffers. By setting ``cache_staged_state_dict=True``, we enable efficient memory copying
        # into a persistent buffer with pinned memory enabled. (``FileSystemWriter`` is the concrete filesystem
        # implementation of the abstract ``StorageWriter`` interface.)
        # Note: It's important that the writer persists in between checkpointing requests, since it maintains the
        # pinned memory buffer.
        writer = FileSystemWriter(CHECKPOINT_DIR, cache_staged_state_dict=True)
        checkpoint_future = None
        for step in range(10):
            optimizer.zero_grad()
            model(torch.rand(8, 16, device="cuda")).sum().backward()
            optimizer.step()

            state_dict = { "app": AppState(model, optimizer) }
            if checkpoint_future is not None:
                # waits for checkpointing to finish, avoiding queuing more than one checkpoint request at a time
                checkpoint_future.result()
            checkpoint_future = dcp.async_save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

        cleanup()


    if __name__ == "__main__":
        world_size = torch.cuda.device_count()
        print(f"Running fsdp checkpoint example on {world_size} devices.")
        mp.spawn(
            run_fsdp_checkpoint_save_example,
            args=(world_size,),
            nprocs=world_size,
            join=True,
        )
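
The benefit of taking checkpointing off the critical path can be measured directly. Below is a hypothetical timing sketch, assuming ``state_dict`` and ``CHECKPOINT_DIR`` are set up as in the examples above; the exact numbers depend on model size and storage throughput:

.. code-block:: python

    import time

    import torch.distributed.checkpoint as dcp

    # synchronous save: training is blocked for the full serialization + write
    start = time.perf_counter()
    dcp.save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_sync")
    print(f"blocking save took {time.perf_counter() - start:.2f}s")

    # asynchronous save: training is blocked only while staging into CPU
    # buffers; the write to storage continues in the background
    start = time.perf_counter()
    future = dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_async")
    print(f"async_save returned after {time.perf_counter() - start:.2f}s")
    future.result()  # wait for the background write before exiting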

Conclusion
----------
In conclusion, we have learned how to use DCP's :func:`async_save` API to generate checkpoints off the critical training path. We've also learned about the
additional memory and concurrency overhead introduced by using this API, as well as additional optimizations which utilize pinned memory to speed things up
even further.

- `Saving and loading models tutorial <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`__
- `Getting started with FullyShardedDataParallel tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__