Commit 18a4806

Updated release notes and API doc for smd model parallel 1.3.1 (#2267)
Co-authored-by: Talia <[email protected]>
1 parent 85321d3 commit 18a4806

2 files changed, +47 -9 lines changed

doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.md

Lines changed: 30 additions & 0 deletions
@@ -1,3 +1,33 @@
# SageMaker Distributed Model Parallel 1.3.1 Release Notes

- New Features
- Bug Fixes
- Known Issues

## New Features

### TensorFlow

- Exposes a new decorator ``register_post_partition_hook``. This allows the decorated method to be invoked just after the model partition, but before the first step executes, for example to load a checkpoint (see the sketch below). Refer to the [SageMaker distributed model parallel API documentation](https://sagemaker.readthedocs.io/en/stable/api/training/smp_versions/latest/smd_model_parallel_tensorflow.html) for more information.
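A minimal sketch of that checkpoint-loading use case, assuming the usual SMP TensorFlow initialization; the checkpoint object and directory below are placeholders, not part of this release:

```python
import tensorflow as tf
import smdistributed.modelparallel.tensorflow as smp

smp.init()

# Placeholder checkpoint; in practice it would track your SMP model and optimizer.
checkpoint = tf.train.Checkpoint(step=tf.Variable(0))

@smp.register_post_partition_hook
def load_checkpoint():
    # Runs right after the model is partitioned, but before the first
    # forward pass of the first smp.step call.
    checkpoint.restore(tf.train.latest_checkpoint("/path/to/ckpt_dir"))
```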
## Bug Fixes

### PyTorch

- Improved memory efficiency when using active microbatches by clearing activations at the end of each microbatch.

### TensorFlow

- Fixed an issue that caused hangs when training some models with XLA enabled.
## Known Issues

### PyTorch

- A crash was observed when ``optimizer.step()`` was called for certain optimizers, such as AdaDelta, when the partition on which the method was called had no local parameters assigned to it after partitioning. This is due to a bug in PyTorch that [has since been fixed](https://github.com/pytorch/pytorch/pull/52944). Until that fix makes its way into the next release of PyTorch, only call ``optimizer.step()`` on processes that have at least one local parameter; this can be checked with ``len(list(model.local_parameters())) > 0`` (see the sketch after this list).

- A performance regression still exists when training on SMP with PyTorch 1.7.1 compared to 1.6. The root cause was found to be a slowdown in `.grad` method calls in PyTorch 1.7.1 compared to 1.6; see the related discussion at https://github.com/pytorch/pytorch/issues/50636. This issue does not exist with PyTorch 1.8.
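A minimal sketch of the suggested guard, assuming the standard SMP PyTorch setup (``smp.init``, ``smp.DistributedModel``, ``smp.DistributedOptimizer``); the tiny linear model below is only a placeholder:

```python
import torch
import smdistributed.modelparallel.torch as smp

smp.init()

# Placeholder model and optimizer; replace with your real training objects.
model = smp.DistributedModel(torch.nn.Linear(8, 8))
optimizer = smp.DistributedOptimizer(torch.optim.Adadelta(model.parameters()))

# ... forward/backward pass inside an @smp.step-decorated training function ...

# Only call optimizer.step() on processes that own at least one local
# parameter after partitioning; other processes skip the call to avoid
# the crash described above.
if len(list(model.local_parameters())) > 0:
    optimizer.step()
```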
# SageMaker Distributed Model Parallel 1.3.0 Release Notes

- New Features

doc/api/training/smp_versions/latest/smd_model_parallel_tensorflow.rst

Lines changed: 17 additions & 9 deletions
@@ -83,7 +83,21 @@ TensorFlow API
     with smp.partition(3):
         z = tf.reduce_sum(y)             # placed in partition 3
 
+.. function:: register_post_partition_hook(hook)
+
+   Registers a callable ``hook`` to be executed after the model is
+   partitioned. This is useful in situations where an operation needs to be
+   executed after the model partition during the first call to ``smp.step``,
+   but before the actual execution of the first forward pass.
+
+   .. code:: python
+
+      @smp.register_post_partition_hook
+      def test_eager():
+          # All statements here will be executed right after partition but before the first forward pass
+          tf.print("Entered hook through eager context")
+
 
 .. class:: smp.CheckpointManager

@@ -102,13 +116,6 @@ TensorFlow API
                       max_to_keep=None,
                       checkpoint_name="ckpt")
 
-
-**Important:** ``smp.CheckpointManager.restore()`` must be called after
-the first training step. This is because the first call of the
-``smp.step`` function constructs and partitions the model, which must
-take place before the checkpoint restore. Calling it before the first
-``smp.step`` call might result in hangs or unexpected behavior.
-
 **Parameters**
 
 - ``checkpoint``: A `tf.train.Checkpoint
@@ -154,7 +161,8 @@ TensorFlow API
 .. code:: python
 
     for step, inputs in enumerate(train_ds):
-        if step == 1:                    # NOTE: restore occurs on the second step
+        if step == 0:
             ckpt_manager.restore()
         loss = train_step(inputs)
