
Commit 383565d

SeanNaren and rohitgr7 authored
Update DeepSpeed docs (#6528)
* Clean up docs and add some explicitness around stages
* Apply suggestions from code review

Co-authored-by: Rohit Gupta <[email protected]>
1 parent: c48fc6a


docs/source/advanced/multi_gpu.rst

Lines changed: 45 additions & 9 deletions
@@ -697,24 +697,23 @@ To use DeepSpeed, you first need to install DeepSpeed using the commands below.
 
 .. code-block:: bash
 
-    pip install deepspeed mpi4py
+    pip install deepspeed
 
 If you run into an issue with the install or later in training, ensure that the CUDA version of the pytorch you've installed matches your locally installed CUDA (you can see which one has been recognized by running ``nvcc --version``).
-Additionally if you run into any issues installing mpi4py, ensure you have openmpi installed using ``sudo apt install libopenmpi-dev`` or ``brew install mpich`` before running ``pip install mpi4py``.
 
 .. note::
     Currently ``resume_from_checkpoint`` and manual optimization are not supported.
 
     DeepSpeed currently only supports single optimizer, single scheduler within the training loop.
 
-ZeRO-Offload
-""""""""""""
+DeepSpeed ZeRO Stage 2
+""""""""""""""""""""""
 
-Below we show an example of running `ZeRO-Offload <https://www.deepspeed.ai/tutorials/zero-offload/>`_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption.
-For even more speed benefit, they offer an optimized CPU version of ADAM to run the offloaded computation, which is faster than the standard PyTorch implementation. By default we enable ZeRO-Offload.
+By default, we enable `DeepSpeed ZeRO Stage 2 <https://www.deepspeed.ai/tutorials/zero/#zero-overview>`_, which partitions your optimizer states (Stage 1) and your gradients (Stage 2) across your GPUs to reduce memory. In most cases this is more efficient than, or at parity with, DDP, primarily due to the optimized custom communications written by the DeepSpeed team.
+As a result, benefits can also be seen on a single GPU. Do note that the default bucket sizes allocate around ``3.6GB`` of VRAM to use during distributed communications, which can be tweaked when instantiating the plugin described in a few sections below.
 
 .. note::
-    To use ZeRO-Offload, you must use ``precision=16`` or set precision via `the DeepSpeed config. <https://www.deepspeed.ai/docs/config-json/#fp16-training-options>`_.
+    To use ZeRO, you must use ``precision=16``.
 
 .. code-block:: python
 
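The body of the Stage 2 example referenced above falls outside this hunk, but a minimal sketch of the default usage described in the new paragraph might look like the following (``MyModel`` is assumed to be a ``LightningModule`` defined elsewhere in the docs; instantiating ``DeepSpeedPlugin()`` with no arguments is assumed to give the Stage 2 defaults, whose bucket sizes can be tweaked as shown in a later hunk):

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin

    # Default DeepSpeed ZeRO Stage 2: optimizer states and gradients are
    # partitioned across GPUs; precision=16 is required for ZeRO.
    model = MyModel()
    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(), precision=16)
    trainer.fit(model)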
@@ -725,6 +724,24 @@ For even more speed benefit, they offer an optimized CPU version of ADAM to run
     trainer.fit(model)
 
 
+DeepSpeed ZeRO Stage 2 Offload
+""""""""""""""""""""""""""""""
+
+Below we show an example of running `ZeRO-Offload <https://www.deepspeed.ai/tutorials/zero-offload/>`_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption.
+
+.. note::
+    To use ZeRO-Offload, you must use ``precision=16``.
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer
+    from pytorch_lightning.plugins import DeepSpeedPlugin
+
+    model = MyModel()
+    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(cpu_offload=True), precision=16)
+    trainer.fit(model)
+
+
 This can also be done via the command line using a Pytorch Lightning script:
 
 .. code-block:: bash
@@ -740,7 +757,7 @@ You can also modify the ZeRO-Offload parameters via the plugin as below.
     from pytorch_lightning.plugins import DeepSpeedPlugin
 
     model = MyModel()
-    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(allgather_bucket_size=5e8, reduce_bucket_size=5e8), precision=16)
+    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(cpu_offload=True, allgather_bucket_size=5e8, reduce_bucket_size=5e8), precision=16)
     trainer.fit(model)
 
 
@@ -752,11 +769,30 @@ You can also modify the ZeRO-Offload parameters via the plugin as below.
 
 The plugin sets a reasonable default of ``2e8``, which should work for most low VRAM GPUs (less than ``7GB``), allocating roughly ``3.6GB`` of VRAM as buffer. Higher VRAM GPUs should aim for values around ``5e8``.
 
+For even more speed benefit, DeepSpeed offers an optimized CPU version of ADAM called `DeepSpeedCPUAdam <https://deepspeed.readthedocs.io/en/latest/optimizers.html#adam-cpu>`_ to run the offloaded computation, which is faster than the standard PyTorch implementation.
+
+.. code-block:: python
+
+    import pytorch_lightning as pl
+    from pytorch_lightning import Trainer
+    from pytorch_lightning.plugins import DeepSpeedPlugin
+    from deepspeed.ops.adam import DeepSpeedCPUAdam
+
+    class MyModel(pl.LightningModule):
+        ...
+        def configure_optimizers(self):
+            # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.Adam(W)
+            return DeepSpeedCPUAdam(self.parameters())
+
+    model = MyModel()
+    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(cpu_offload=True), precision=16)
+    trainer.fit(model)
+
 
 Custom DeepSpeed Config
 """""""""""""""""""""""
 
-DeepSpeed allows use of custom DeepSpeed optimizers and schedulers defined within a config file. This allows you to enable optimizers such as `1-bit Adam <https://www.deepspeed.ai/tutorials/onebit-adam/>`_.
+In some cases you may want to define your own DeepSpeed config to access all available parameters. We've exposed most of the important parameters; however, there may be debugging parameters you need to enable. DeepSpeed also allows the use of custom DeepSpeed optimizers and schedulers defined within a config file.
 
 .. note::
     All plugin default parameters will be ignored when a config object is passed.
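The custom config example itself lies beyond this hunk, but a minimal sketch of the idea, assuming ``DeepSpeedPlugin`` accepts a config dictionary and that the keys follow DeepSpeed's documented JSON schema (both assumptions, not shown in this diff), might look like:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin

    # Illustrative config only; see https://www.deepspeed.ai/docs/config-json/ for the full schema.
    deepspeed_config = {
        "zero_optimization": {
            "stage": 2,
            "cpu_offload": True,
        },
    }

    model = MyModel()
    # Per the note above, plugin default parameters are ignored once a config is passed.
    trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(deepspeed_config), precision=16)
    trainer.fit(model)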
