Skip to content

Commit 432e563

Browse files
SeanNarencarmocca
andauthored
Expose DeepSpeed FP16 parameters due to loss instability (#6115)
* Expose deepspeed config parameters to init function due to instability in parameters * See if tests can run on normal CI, without special tests * Add changelog * Update pytorch_lightning/plugins/training_type/deepspeed.py Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 3b0e4e0 commit 432e563

File tree

4 files changed

+92
-15
lines changed

4 files changed

+92
-15
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3333
- Move lightning module to correct device type when using LightningDistributedWrapper ([#6070](https://github.com/PyTorchLightning/pytorch-lightning/pull/6070))
3434

3535

36+
- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))
37+
38+
3639
## [1.2.0] - 2021-02-18
3740

3841
### Added

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ def __init__(
7979
num_nodes: int = 1,
8080
parallel_devices: Optional[List[torch.device]] = None,
8181
cluster_environment: Optional[ClusterEnvironment] = None,
82+
loss_scale: float = 0,
83+
initial_scale_power: int = 32,
84+
loss_scale_window: int = 1000,
85+
hysteresis: int = 2,
86+
min_loss_scale: int = 1
8287
) -> None:
8388
"""
8489
@@ -127,6 +132,18 @@ def __init__(
127132
128133
logging_level: Set logging level for deepspeed. (Default: ``logging.WARN``)
129134
135+
loss_scale: Loss scaling value for FP16 training.
136+
0.0 results in dynamic loss scaling, otherwise static (Default: 0)
137+
138+
initial_scale_power: Power of the initial dynamic loss scale value. Loss scale is computed
139+
by ``2^initial_scale_power`` (Default: 32)
140+
141+
loss_scale_window: Window in which to raise/lower the dynamic FP16 loss scaling value (Default: 1000)
142+
143+
hysteresis: FP16 Delay shift in Dynamic Loss scaling (Default: 2)
144+
145+
min_loss_scale: The minimum FP16 dynamic loss scaling value (Default: 1)
146+
130147
"""
131148
if not _DEEPSPEED_AVAILABLE:
132149
raise MisconfigurationException(
@@ -154,6 +171,13 @@ def __init__(
154171
self._config_initialized = False
155172
deepspeed.utils.logging.logger.setLevel(logging_level)
156173

174+
# default FP16 parameters.
175+
self.loss_scale = loss_scale
176+
self.initial_scale_power = initial_scale_power
177+
self.loss_scale_window = loss_scale_window
178+
self.hysteresis = hysteresis
179+
self.min_loss_scale = min_loss_scale
180+
157181
def _load_config(self, config):
158182
if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
159183
rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
@@ -297,9 +321,19 @@ def _format_precision_config(self):
297321
amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
298322
precision = self.lightning_module.trainer.accelerator_connector.precision
299323
if precision == 16:
300-
if "amp" not in self.config and amp_type == AMPType.NATIVE:
301-
self.config["fp16"] = {"enabled": True}
302-
elif "apex" not in self.config and amp_type == AMPType.APEX:
324+
if "fp16" not in self.config and amp_type == AMPType.NATIVE:
325+
# FP16 is a DeepSpeed standalone AMP implementation
326+
rank_zero_info("Enabling DeepSpeed FP16.")
327+
self.config["fp16"] = {
328+
"enabled": True,
329+
"loss_scale": self.loss_scale,
330+
"initial_scale_power": self.initial_scale_power,
331+
"loss_scale_window": self.loss_scale_window,
332+
"hysteresis": self.hysteresis,
333+
"min_loss_scale": self.min_loss_scale
334+
}
335+
elif "amp" not in self.config and amp_type == AMPType.APEX:
336+
rank_zero_info("Enabling DeepSpeed APEX Implementation.")
303337
self.config["amp"] = {
304338
"enabled": True,
305339
"opt_level": amp_level,

tests/plugins/test_deepspeed_plugin.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,6 @@ def test_invalid_deepspeed_defaults_no_precision(tmpdir):
211211

212212
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
213213
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
214-
@pytest.mark.skipif(
215-
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
216-
)
217214
def test_warn_deepspeed_override_backward(tmpdir):
218215
"""
219216
Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning.
@@ -232,9 +229,6 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
232229

233230
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
234231
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
235-
@pytest.mark.skipif(
236-
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
237-
)
238232
def test_deepspeed_run_configure_optimizers(tmpdir):
239233
"""
240234
Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation),
@@ -268,9 +262,6 @@ def on_train_start(self) -> None:
268262

269263
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
270264
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
271-
@pytest.mark.skipif(
272-
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
273-
)
274265
def test_deepspeed_config(tmpdir, deepspeed_zero_config):
275266
"""
276267
Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers
@@ -304,6 +295,58 @@ def on_train_start(self) -> None:
304295
_assert_save_model_is_equal(model, tmpdir, trainer)
305296

306297

298+
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
299+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
300+
def test_deepspeed_custom_precision_params(tmpdir):
301+
"""
302+
Ensure if we modify the FP16 parameters via the DeepSpeedPlugin, the deepspeed config contains these changes.
303+
"""
304+
305+
class TestModel(BoringModel):
306+
307+
def on_train_start(self) -> None:
308+
assert self.trainer.training_type_plugin.config['fp16']['loss_scale'] == 10
309+
assert self.trainer.training_type_plugin.config['fp16']['initial_scale_power'] == 10
310+
assert self.trainer.training_type_plugin.config['fp16']['loss_scale_window'] == 10
311+
assert self.trainer.training_type_plugin.config['fp16']['hysteresis'] == 10
312+
assert self.trainer.training_type_plugin.config['fp16']['min_loss_scale'] == 10
313+
raise SystemExit()
314+
315+
model = TestModel()
316+
trainer = Trainer(
317+
plugins=[
318+
DeepSpeedPlugin(
319+
loss_scale=10, initial_scale_power=10, loss_scale_window=10, hysteresis=10, min_loss_scale=10
320+
)
321+
],
322+
precision=16,
323+
gpus=1
324+
)
325+
with pytest.raises(SystemExit):
326+
trainer.fit(model)
327+
328+
329+
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
330+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
331+
def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_config):
332+
"""
333+
Ensure if we use a config and turn off cpu_offload, that this is set to False within the config.
334+
"""
335+
336+
deepspeed_zero_config['zero_optimization']['cpu_offload'] = False
337+
338+
class TestModel(BoringModel):
339+
340+
def on_train_start(self) -> None:
341+
assert self.trainer.training_type_plugin.config['zero_optimization']['cpu_offload'] is False
342+
raise SystemExit()
343+
344+
model = TestModel()
345+
trainer = Trainer(plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)], precision=16, gpus=1)
346+
with pytest.raises(SystemExit):
347+
trainer.fit(model)
348+
349+
307350
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
308351
@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
309352
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")

tests/special_tests.sh

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@ export PL_RUNNING_SPECIAL_TESTS=1
1717
DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no"
1818
python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp
1919
python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp
20-
python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_warn_deepspeed_override_backward
21-
python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_run_configure_optimizers
22-
python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_config
2320
python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_multigpu
2421
python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
2522
python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual

0 commit comments

Comments
 (0)