
Commit 2292b30

Add API enhancements for SMP (#2048)
* Add API enhancements for SMP
* Update doc/api/training/smd_model_parallel_common_api.rst
* Update doc/api/training/smd_model_parallel_pytorch.rst
* Update doc/api/training/smd_model_parallel_common_api.rst

Co-authored-by: Aaron Markham <[email protected]>
1 parent ffa9dc3 commit 2292b30

File tree: 2 files changed (+41, -0 lines)


doc/api/training/smd_model_parallel_common_api.rst

Lines changed: 31 additions & 0 deletions
@@ -57,6 +57,37 @@ The following APIs are common across all frameworks.
versions of the tensor across different microbatches
(see ``StepOutput`` entry for more information).

The argument to an ``smp.step``-decorated function should be a tensor or an
instance of list, tuple, dict, or set for it to be split across microbatches.
If your object does not fall into one of these categories, you can make the
library split it by implementing the ``smp_slice`` method.

Below is an example of how to use it with PyTorch.

.. code:: python

    class CustomType:
        def __init__(self, tensor):
            self.data = tensor

        # The library will call this to invoke slicing on the object, passing in
        # the total number of microbatches (num_mb) and the current microbatch index (mb).
        def smp_slice(self, num_mb, mb, axis):
            dim_size = list(self.data.size())[axis]

            split_size = dim_size // num_mb
            sliced_tensor = self.data.narrow(axis, mb * split_size, split_size)
            return CustomType(sliced_tensor)

    custom_obj = CustomType(torch.ones(4,))

    @smp.step()
    def step(custom_obj):
        loss = model(custom_obj)
        model.backward(loss)
        return loss
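To make the ``smp_slice`` contract concrete, the following sketch slices the ``CustomType`` instance from the example above the way a two-microbatch split would. It is plain PyTorch with no SMP calls; the helper ``split_into_microbatches`` is illustrative only and is not part of the library API.

.. code:: python

    import torch

    # Illustrative helper, not a library API: mimics how a step decorator
    # could slice an argument that implements smp_slice.
    def split_into_microbatches(obj, num_mb, axis=0):
        return [obj.smp_slice(num_mb, mb, axis) for mb in range(num_mb)]

    custom_obj = CustomType(torch.ones(4))  # reuses CustomType from the example above
    microbatches = split_into_microbatches(custom_obj, num_mb=2)
    print([mb.data for mb in microbatches])
    # [tensor([1., 1.]), tensor([1., 1.])]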
**Important:** ``smp.step`` splits the batch into microbatches, and
executes everything inside the decorated function once per microbatch.
This might affect the behavior of batch normalization, any operation

doc/api/training/smd_model_parallel_pytorch.rst

Lines changed: 10 additions & 0 deletions
@@ -128,6 +128,11 @@ This API document assumes you use the following import statements in your traini
computation. ``bucket_cap_mb`` controls the bucket size in MegaBytes
(MB).

- ``trace_memory_usage`` (default: False): When set to True, the library
  attempts to measure memory usage per module during tracing. If this is
  disabled, memory usage is estimated from the sizes of the tensors returned
  by the module. (A usage sketch follows below.)
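The parameter list above appears to belong to ``smp.DistributedModel``, the same class whose ``bucket_cap_mb`` argument is documented just before it. Below is a minimal usage sketch under that assumption, using the standard SMP PyTorch import; treat the exact call as illustrative rather than a verified signature.

.. code:: python

    # Hedged sketch: assumes trace_memory_usage is accepted as a keyword
    # argument of smp.DistributedModel, alongside bucket_cap_mb.
    import torch.nn as nn
    import smdistributed.modelparallel.torch as smp

    smp.init()
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    model = smp.DistributedModel(model, trace_memory_usage=True)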
**Properties**

- ``partitioned``: Is ``True`` if the model is partitioned, ``False``
@@ -215,6 +220,11 @@ This API document assumes you use the following import statements in your traini
first forward pass. Returns a ``RemovableHandle`` object ``handle``,
which can be used to remove the hook by calling ``handle.remove()``.

.. function:: cpu( )

   Allgathers parameters and buffers across all ``mp_rank``\ s and moves them
   to the CPU.
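One plausible use of ``cpu()``: because it allgathers parameters and buffers across ``mp_rank``\ s, it can collect a full copy of the model on the CPU before saving from a single rank. The sketch below assumes ``model`` is an ``smp.DistributedModel`` wrapper and that ``smp.rank()`` and ``state_dict()`` behave as elsewhere in this API; it is an illustration, not a prescribed checkpointing recipe.

.. code:: python

    # Hedged sketch: gather a full parameter copy on CPU, then save from rank 0.
    import torch
    import smdistributed.modelparallel.torch as smp

    model.cpu()  # allgather across mp_ranks and move parameters/buffers to CPU
    if smp.rank() == 0:
        torch.save(model.state_dict(), "model.pt")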
.. class:: smp.DistributedOptimizer

**Parameters**
