Commit bbb46b1

fix kwargs and descriptions
1 parent 907f4ff commit bbb46b1

1 file changed: +11 -2 lines changed
doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst

Lines changed: 11 additions & 2 deletions
@@ -729,7 +729,7 @@ smdistributed.modelparallel.torch APIs for Saving and Loading
 * ``num_kept_partial_checkpoints`` (int) (default: None): The maximum number
 of partial checkpoints to keep on disk.

-.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer_states=True, translate_function=None)
+.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer=True, load_optimizer_states=True, translate_function=None)

 While :class:`smdistributed.modelparallel.torch.load` loads saved
 model and optimizer objects, this function resumes from a saved checkpoint file.
@@ -742,7 +742,16 @@ smdistributed.modelparallel.torch APIs for Saving and Loading
 * ``partial`` (boolean) (default: True): Whether to load the partial checkpoint.
 * ``strict`` (boolean) (default: True): Load with strict load, no extra key or
 missing key is allowed.
-* ``load_optimizer_states`` (boolean) (default: True): Whether to load ``optimizer_states``.
+* ``load_optimizer`` (boolean) (default: True): Whether to load ``optimizer``.
+* ``load_sharded_optimizer_state`` (boolean) (default: True): Whether to load
+the sharded optimizer state of a model.
+It can be used only when you activate
+the `sharded data parallelism
+<https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-extended-features-pytorch-sharded-data-parallelism.html>`_
+feature of the SageMaker model parallel library.
+When this is ``False``, the library only loads the FP16
+states, such as FP32 master parameters and the loss scaling factor,
+not the sharded optimizer states.
 * ``translate_function`` (function) (default: None): function to translate the full
 checkpoint into smdistributed.modelparallel format.
 For supported models, this is not required.
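
Reader's note: the added lines in this hunk describe two flags, ``load_optimizer`` and ``load_sharded_optimizer_state``, while the signature line in the first hunk spells the second keyword ``load_optimizer_states``. The sketch below shows how a training script might pass these flags. It assumes the names from the parameter list are the ones the library accepts. The helper names ``resume_kwargs`` and ``resume`` and the checkpoint path are hypothetical and do not come from the commit. The actual call only makes sense inside a SageMaker training job where the library is installed and initialized.

```python
from typing import Any, Dict, Optional


def resume_kwargs(
    tag: Optional[str] = None,
    load_optimizer: bool = True,
    load_sharded_optimizer_state: bool = True,
) -> Dict[str, Any]:
    """Collect the keyword arguments described in the parameter list of this diff.

    Assumption: the keyword is ``load_sharded_optimizer_state`` (as in the
    parameter list), not ``load_optimizer_states`` (as in the signature line).
    """
    return {
        "tag": tag,
        "partial": True,   # documented default
        "strict": True,    # documented default
        "load_optimizer": load_optimizer,
        "load_sharded_optimizer_state": load_sharded_optimizer_state,
        "translate_function": None,  # documented default
    }


def resume(path: str, **overrides: Any) -> bool:
    """Call resume_from_checkpoint if the library is importable.

    Returns True when the call was made and False when the library is absent.
    """
    try:
        import smdistributed.modelparallel.torch as smp
    except ImportError:
        return False
    # Assumption: the library has already been initialized by the training script.
    smp.resume_from_checkpoint(path, **resume_kwargs(**overrides))
    return True


if __name__ == "__main__":
    # Hypothetical path. Skip the sharded optimizer state so that, per the
    # description above, only the FP16 states are loaded.
    resume("/opt/ml/checkpoints", load_sharded_optimizer_state=False)
```

Setting ``load_sharded_optimizer_state=False`` is only meaningful when sharded data parallelism is enabled, as the description above states. If the library rejects the keyword, the signature-line spelling is the other candidate to try.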
