documentation: Add SMP 1.2.0 API docs #2098

Merged · 3 commits · Jan 20, 2021
Changes from all commits
doc/api/training/smd_model_parallel.rst: 1 addition, 0 deletions
@@ -43,6 +43,7 @@ Select a version to see the API documentation for version. To use the library, r
.. toctree::
:maxdepth: 1

smp_versions/v1_2_0.rst
smp_versions/v1_1_0.rst

It is recommended to use this documentation alongside `SageMaker Distributed Model Parallel
@@ -24,10 +24,12 @@ The following SageMaker distribute model parallel APIs are common across all fra


.. function:: smp.init( )
:noindex:

Initialize the library. Must be called at the beginning of the training script.

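For orientation, a minimal sketch of the expected call order, assuming the PyTorch variant of the library and its conventional import alias (the import statement itself is elided from this excerpt)::

    import smdistributed.modelparallel.torch as smp  # assumed import alias

    def main():
        smp.init()  # must run before any other smp call
        # ... build the model, wrap it, and run the training loop ...

    if __name__ == "__main__":
        main()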
.. function:: @smp.step(non_split_inputs, input_split_axes, [*args, **kwargs])
:noindex:

A decorator that must be placed over a function that represents a single
forward and backward pass (for training use cases), or a single forward
@@ -159,6 +161,7 @@ The following SageMaker distribute model parallel APIs are common across all fra


.. class:: StepOutput
:noindex:


A class that encapsulates all versions of a ``tf.Tensor``
@@ -188,27 +191,32 @@ The following SageMaker distribute model parallel APIs are common across all fra
post-processing operations on tensors.

.. data:: StepOutput.outputs
:noindex:

Returns a list of the underlying tensors, indexed by microbatch.

.. function:: StepOutput.reduce_mean( )
:noindex:

Returns a ``tf.Tensor`` / ``torch.Tensor`` that averages the constituent ``tf.Tensor``\ s /
``torch.Tensor``\ s. This is commonly used for averaging loss and gradients across microbatches.
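As a hedged illustration of that pattern (PyTorch flavor; ``model``, ``data``, and ``target`` are assumed to come from the surrounding training script, and ``model.backward`` is documented in the PyTorch file of this PR)::

    import torch.nn.functional as F

    @smp.step
    def train_step(model, data, target):
        output = model(data)
        loss = F.nll_loss(output, target)
        model.backward(loss)                    # distributed backward pass (PyTorch API)
        return loss

    loss_mb = train_step(model, data, target)   # StepOutput: one tensor per microbatch
    loss = loss_mb.reduce_mean()                # average the per-microbatch losses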

.. function:: StepOutput.reduce_sum( )
:noindex:

Returns a ``tf.Tensor`` /
``torch.Tensor`` that sums the constituent
``tf.Tensor``\ s/\ ``torch.Tensor``\ s.

.. function:: StepOutput.concat( )
:noindex:

Returns a
``tf.Tensor``/``torch.Tensor`` that concatenates tensors along the
batch dimension using ``tf.concat`` / ``torch.cat``.

.. function:: StepOutput.stack( )
:noindex:

Applies ``tf.stack`` / ``torch.stack``
operation to the list of constituent ``tf.Tensor``\ s /
@@ -217,13 +225,15 @@ The following SageMaker distribute model parallel APIs are common across all fra
**TensorFlow-only methods**

.. function:: StepOutput.merge( )
:noindex:

Returns a ``tf.Tensor`` that
concatenates the constituent ``tf.Tensor``\ s along the batch
dimension. This is commonly used for merging the model predictions
across microbatches.

.. function:: StepOutput.accumulate(method="variable", var=None)
:noindex:

Functionally the same as ``StepOutput.reduce_mean()``. However, it is
more memory-efficient, especially for large numbers of microbatches,
@@ -249,6 +259,7 @@ The following SageMaker distribute model parallel APIs are common across all fra
ignored.

.. _mpi_basics:
:noindex:

MPI Basics
^^^^^^^^^^
@@ -271,7 +282,8 @@ The library exposes the following basic MPI primitives to its Python API:
- ``smp.get_dp_group()``: The list of ranks that hold different
replicas of the same model partition.

.. _communication_api:
:noindex:

Communication API
^^^^^^^^^^^^^^^^^
@@ -285,6 +297,7 @@ should involve.
**Helper structures**

.. data:: smp.CommGroup
:noindex:

An ``enum`` that takes the values
``CommGroup.WORLD``, ``CommGroup.MP_GROUP``, and ``CommGroup.DP_GROUP``.
@@ -303,6 +316,7 @@ should involve.
themselves.

.. data:: smp.RankType
:noindex:

An ``enum`` that takes the values
``RankType.WORLD_RANK``, ``RankType.MP_RANK``, and ``RankType.DP_RANK``.
@@ -318,6 +332,7 @@ should involve.
**Communication primitives:**

.. function:: smp.broadcast(obj, group)
:noindex:

Sends the object to all processes in the
group. The receiving process must call ``smp.recv_from`` to receive the
@@ -350,6 +365,7 @@ should involve.
    smp.recv_from(0, rank_type=smp.RankType.WORLD_RANK)

.. function:: smp.send(obj, dest_rank, rank_type)
:noindex:

Sends the object ``obj`` to
``dest_rank``, which is of a type specified by ``rank_type``.
@@ -373,6 +389,7 @@ should involve.
``recv_from`` call.

.. function:: smp.recv_from(src_rank, rank_type)
:noindex:

Receive an object from a peer process. Can be used with a matching
``smp.send`` or a ``smp.broadcast`` call.
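A short sketch of how a ``broadcast``/``recv_from`` pair fits together, mirroring the snippet above; ``smp.rank()`` is assumed from the MPI basics (elided in this excerpt), and the payload is illustrative::

    config = {"seq_len": 512}                    # illustrative payload

    if smp.rank() == 0:                          # assumed rank helper
        smp.broadcast(config, smp.CommGroup.WORLD)
    else:
        config = smp.recv_from(0, rank_type=smp.RankType.WORLD_RANK)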
@@ -398,6 +415,7 @@ should involve.
``broadcast`` call, and the object is received.

.. function:: smp.allgather(obj, group)
:noindex:

A collective call that gathers all the
submitted objects across all ranks in the specified ``group``. Returns a
@@ -431,6 +449,7 @@ should involve.
    out = smp.allgather(obj2, smp.CommGroup.MP_GROUP)  # returns [obj1, obj2]

.. function:: smp.barrier(group=smp.WORLD)
:noindex:

A statement that hangs until all
processes in the specified group reach the barrier statement, similar to
@@ -452,12 +471,14 @@ should involve.
processes outside that ``mp_group``.

.. function:: smp.dp_barrier()
:noindex:

Same as passing ``smp.DP_GROUP`` to ``smp.barrier()``.
Waits for the processes in the same \ ``dp_group`` as
the current process to reach the same point in execution.

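For example, a minimal sketch of waiting for the data-parallel peers before a phase change; ``train_one_epoch`` and ``evaluate`` are hypothetical helpers::

    train_one_epoch()   # hypothetical helper
    smp.dp_barrier()    # wait for the other replicas in this dp_group
    evaluate()          # hypothetical helper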
.. function:: smp.mp_barrier()
:noindex:

Same as passing ``smp.MP_GROUP`` to
``smp.barrier()``. Waits for the processes in the same ``mp_group`` as
@@ -23,6 +23,7 @@ This API document assumes you use the following import statements in your traini
to learn how to use the following API in your PyTorch training script.

.. class:: smp.DistributedModel
:noindex:

A sub-class of ``torch.nn.Module`` which specifies the model to be
partitioned. Accepts a ``torch.nn.Module`` object ``module`` which is
@@ -157,6 +158,7 @@ This API document assumes you use the following import statements in your traini
**Methods**

.. function:: backward(tensors, grad_tensors)
:noindex:

Triggers a distributed backward
pass across model partitions. Example usage provided in the previous
@@ -165,42 +167,49 @@ This API document assumes you use the following import statements in your traini
``retain_grad`` and ``create_graph``  flags are not supported.

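Because the example referenced above is elided from this diff, here is a hedged sketch of the wrap-and-train pattern that ``smp.DistributedModel`` and ``backward`` describe; ``Net``, the optimizer choice, and ``loader`` are illustrative, and ``smp.DistributedOptimizer`` is documented further down in this file::

    import torch
    import torch.nn.functional as F
    import smdistributed.modelparallel.torch as smp  # assumed import alias

    class Net(torch.nn.Module):                      # minimal illustrative model
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(8, 2)

        def forward(self, x):
            return self.fc(x)

    smp.init()
    model = smp.DistributedModel(Net())              # module to be partitioned
    optimizer = smp.DistributedOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))

    @smp.step
    def train_step(model, data, target):
        loss = F.nll_loss(model(data), target)
        model.backward(loss)                         # replaces loss.backward()
        return loss

    for data, target in loader:                      # illustrative DataLoader
        optimizer.zero_grad()
        loss_mb = train_step(model, data, target)
        optimizer.step()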
.. function:: local_buffers( )
:noindex:

Returns an iterator over buffers for the modules in
the partitioned model that have been assigned to the current process.

.. function:: local_named_buffers( )
:noindex:

Returns an iterator over buffers for the
modules in the partitioned model that have been assigned to the current
process. This yields both the name of the buffer as well as the buffer
itself.

.. function:: local_parameters( )
:noindex:

Returns an iterator over parameters for the
modules in the partitioned model that have been assigned to the current
process.

.. function:: local_named_parameters( )
:noindex:

Returns an iterator over parameters for
the modules in the partitioned model that have been assigned to the
current process. This yields both the name of the parameter as well as
the parameter itself.

.. function:: local_modules( )
:noindex:

Returns an iterator over the modules in the
partitioned model that have been assigned to the current process.

.. function:: local_named_modules( )
:noindex:

Returns an iterator over the modules in the
partitioned model that have been assigned to the current process. This
yields both the name of the module as well as the module itself.

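A small sketch of what these local iterators are typically used for, e.g. reporting the size of the partition held by the current process; the report format is illustrative::

    local_param_count = sum(p.numel() for p in model.local_parameters())
    local_modules = [name for name, _ in model.local_named_modules()]
    print(f"this rank holds {local_param_count} parameters "
          f"across {len(local_modules)} modules")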
.. function:: local_state_dict( )
:noindex:

Returns the ``state_dict`` that contains local
parameters that belong to the current \ ``mp_rank``. This ``state_dict``
@@ -210,34 +219,39 @@ This API document assumes you use the following import statements in your traini
partition, or to the entire model.

.. function:: state_dict( )
:noindex:

Returns the ``state_dict`` that contains parameters
for the entire model. It first collects the ``local_state_dict`` from all
``mp_rank``\ s, then gathers and merges them to create a full ``state_dict``.

.. function:: load_state_dict( )
:noindex:

Same as ``torch.nn.Module.load_state_dict()``, except that
it first gathers and merges the ``state_dict``\ s across
``mp_rank``\ s, if they are partial. The actual loading happens after the
model partition so that each rank knows its local parameters.

.. function:: register_post_partition_hook(hook)
:noindex:

Registers a callable ``hook`` to
be executed after the model is partitioned. This is useful in situations
where an operation needs to be executed after the model partition during
the first call to ``smp.step``, but before the actual execution of the
first forward pass. Returns a ``RemovableHandle`` object ``handle``,
which can be used to remove the hook by calling ``handle.remove()``.

.. function:: cpu( )
:noindex:

Allgathers parameters and buffers across all ``mp_rank``\ s and moves them
to the CPU.

.. class:: smp.DistributedOptimizer
:noindex:

**Parameters**
- ``optimizer``
@@ -246,13 +260,15 @@ This API document assumes you use the following import statements in your traini
returns ``optimizer`` with the following methods overridden:

.. function:: state_dict( )
:noindex:

Returns the ``state_dict`` that contains optimizer state for the entire model.
It first collects the ``local_state_dict`` from all ``mp_rank``\ s, then gathers
and merges them to create a full ``state_dict``.

.. function:: load_state_dict( )
:noindex:

Same as ``torch.optim.Optimizer.load_state_dict()``, except:

@@ -262,6 +278,7 @@ This API document assumes you use the following import statements in your traini
rank knows its local parameters.

.. function:: local_state_dict( )
:noindex:

Returns the ``state_dict`` that contains the
local optimizer state that belongs to the current \ ``mp_rank``. This
@@ -308,70 +325,79 @@ This API document assumes you use the following import statements in your traini
        self.child3 = Child3()                # child3 on default_partition

.. function:: smp.get_world_process_group( )
:noindex:

Returns a ``torch.distributed`` ``ProcessGroup`` that consists of all
processes, which can be used with the ``torch.distributed`` API.
Requires ``"ddp": True`` in SageMaker Python SDK parameters.

.. function:: smp.get_mp_process_group( )
:noindex:

Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the
processes in the ``MP_GROUP`` which contains the current process, which
can be used with the \ ``torch.distributed`` API. Requires
``"ddp": True`` in SageMaker Python SDK parameters.

.. function:: smp.get_dp_process_group( )
:noindex:

Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the
processes in the ``DP_GROUP`` which contains the current process, which
can be used with the \ ``torch.distributed`` API. Requires
``"ddp": True`` in SageMaker Python SDK parameters.

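For instance, a hedged sketch of passing one of these groups to a ``torch.distributed`` collective, assuming the job was launched with ``"ddp": True``; the metric tensor is illustrative::

    import torch
    import torch.distributed as dist

    metric = torch.tensor([123.0]).cuda()              # illustrative per-replica value
    dist.all_reduce(metric, op=dist.ReduceOp.SUM,
                    group=smp.get_dp_process_group())  # sum across the DP_GROUP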
.. function:: smp.is_initialized( )
:noindex:

Returns ``True`` if ``smp.init`` has already been called for the
process, and ``False`` otherwise.

.. function:: smp.is_tracing( )
:noindex:

Returns ``True`` if the current process is running the tracing step, and
``False`` otherwise.

.. data:: smp.nn.FusedLayerNorm
:noindex:

`Apex Fused Layer Norm <https://nvidia.github.io/apex/layernorm.html>`__ is currently not
supported by the library. ``smp.nn.FusedLayerNorm`` replaces ``apex``
``FusedLayerNorm`` and provides the same functionality. This requires
``apex`` to be installed on the system.

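Since the entry above states that the replacement provides the same functionality, a drop-in sketch might look like the following; the constructor argument is assumed to mirror ``apex.normalization.FusedLayerNorm`` (``normalized_shape``), and the hidden size is illustrative::

    # before: from apex.normalization import FusedLayerNorm
    layer_norm = smp.nn.FusedLayerNorm(1024)   # assumed apex-compatible signature
    hidden = layer_norm(hidden)                # used like any torch.nn.Module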
.. data:: smp.optimizers.FusedNovoGrad
:noindex:

`Fused Novo Grad optimizer <https://nvidia.github.io/apex/optimizers.html#apex.optimizers.FusedNovoGrad>`__ is
currently not supported by the library. ``smp.optimizers.FusedNovoGrad`` replaces ``apex`` ``FusedNovoGrad``
optimizer and provides the same functionality. This requires ``apex`` to
be installed on the system.

.. data:: smp.optimizers.FusedLamb
:noindex:

`FusedLamb optimizer <https://nvidia.github.io/apex/optimizers.html#apex.optimizers.FusedLAMB>`__
currently doesn’t work with the library. ``smp.optimizers.FusedLamb`` replaces
``apex`` ``FusedLamb`` optimizer and provides the same functionality.
This requires ``apex`` to be installed on the system.

.. data:: smp.amp.GradScaler
:noindex:

`Torch AMP Gradscaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`__
currently doesn’t work with the library. ``smp.amp.GradScaler`` replaces
``torch.amp.GradScaler`` and provides the same functionality.

.. _pytorch_saving_loading:
:noindex:

APIs for Saving and Loading
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. function:: smp.save( )
:noindex:

Saves an object. This operation is similar to ``torch.save()``, except
it has an additional keyword argument, ``partial``, and accepts only
@@ -394,6 +420,7 @@ APIs for Saving and Loading
override the default protocol.

.. function:: smp.load( )
:noindex:

Loads an object saved with ``smp.save()`` from a file.

@@ -418,6 +445,7 @@
Should be used when loading a model trained with the library.

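Putting the pieces together, a hedged checkpointing sketch; the positional arguments of ``smp.save``/``smp.load`` are assumed to follow ``torch.save``/``torch.load`` (an object plus a path), the ``partial`` flag on ``smp.load`` is assumed to mirror the one documented for ``smp.save``, ``smp.mp_rank()`` is assumed from the MPI basics elided above, and the file names are illustrative::

    # each mp_rank saves its own partial checkpoint
    smp.save(model.local_state_dict(), f"model_mp{smp.mp_rank()}.pt", partial=True)
    smp.save(optimizer.local_state_dict(), f"opt_mp{smp.mp_rank()}.pt", partial=True)

    # later: reload; load_state_dict gathers and merges the partial dicts across mp_ranks
    model.load_state_dict(smp.load(f"model_mp{smp.mp_rank()}.pt", partial=True))
    optimizer.load_state_dict(smp.load(f"opt_mp{smp.mp_rank()}.pt", partial=True))

In practice you would typically guard the ``smp.save`` calls so that only one process per model partition writes each file.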
.. _pytorch_saving_loading_instructions:
:noindex:

General Instruction For Saving and Loading
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^