Skip to content

Commit ff1002b

Browse files
committed
Add line about state dict
1 parent a47d3b3 commit ff1002b

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

doc/api/training/smp_versions/v1.1.0/smd_model_parallel_pytorch.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,9 @@ This API document assumes you use the following import statements in your traini
265265
Returns the ``state_dict`` that contains optimizer state for the entire model.
266266
It first collects the ``local_state_dict`` and gathers and merges
267267
the ``local_state_dict`` from all ``mp_rank``s to create a full
268-
``state_dict``.
268+
``state_dict``. Please note that this needs to be called on all ranks with
269+
``dp_rank()==0`` to ensure the gather happens properly.
270+
If it is only called on all such ranks, it can hang.
269271

270272
.. function:: load_state_dict( )
271273
:noindex:

doc/api/training/smp_versions/v1.2.0/smd_model_parallel_pytorch.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,9 @@ This API document assumes you use the following import statements in your traini
232232
Returns the ``state_dict`` that contains parameters
233233
for the entire model. It first collects the \ ``local_state_dict``  and
234234
gathers and merges the \ ``local_state_dict`` from all ``mp_rank``\ s to
235-
create a full ``state_dict``.
235+
create a full ``state_dict``. Please note that this needs to be called on all ranks with
236+
``dp_rank()==0`` to ensure the gather happens properly.
237+
If it is only called on all such ranks, it can hang.
236238

237239
.. function:: load_state_dict( )
238240

doc/api/training/smp_versions/v1.3.0/smd_model_parallel_pytorch.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,9 @@ This API document assumes you use the following import statements in your traini
232232
Returns the ``state_dict`` that contains parameters
233233
for the entire model. It first collects the \ ``local_state_dict``  and
234234
gathers and merges the \ ``local_state_dict`` from all ``mp_rank``\ s to
235-
create a full ``state_dict``.
235+
create a full ``state_dict``. Please note that this needs to be called on all ranks with
236+
``dp_rank()==0`` to ensure the gather happens properly.
237+
If it is only called on all such ranks, it can hang.
236238

237239
.. function:: load_state_dict( )
238240

0 commit comments

Comments
 (0)