Commit 8001987

[fix] Ensure we check deepspeed/sharded in multinode DDP (#6297)
* Ensure we check deepspeed/sharded in multinode
* Add CHANGELOG.md
* Add CHANGELOG.md
* Drop mock, use actual multi-gpu node
1 parent b46d221 commit 8001987

File tree (3 files changed: +34, -5 lines)

* CHANGELOG.md
* pytorch_lightning/trainer/connectors/accelerator_connector.py
* tests/accelerators/test_accelerator_connector.py

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -86,6 +86,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
 
 
+- Fixed error thrown when using a valid distributed mode in multi-node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297))
+
+
 ## [1.2.1] - 2021-02-23
 
 ### Fixed

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 5 additions & 5 deletions

@@ -536,12 +536,12 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
         if self.distributed_backend == "horovod":
             self._set_horovod_backend()
 
-        # throw error to force user ddp or ddp2 choice
-        _ddp = (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
-        if (self.num_nodes > 1 and self._distrib_type not in _ddp):
+        using_valid_distributed = self.use_ddp or self.use_ddp2
+        if self.num_nodes > 1 and not using_valid_distributed:
+            # throw error to force user to choose a supported distributed type such as ddp or ddp2
             raise MisconfigurationException(
-                'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
-                'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`'
+                'Your chosen distributed type does not support num_nodes > 1. '
+                'Please set accelerator=ddp or accelerator=ddp2.'
             )
 
         rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}')
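
For context, the updated check keys off `use_ddp`/`use_ddp2` instead of a hard-coded tuple of DDP types, so per the commit title and the new test below, DDP-based variants such as the sharded and DeepSpeed plugins now pass multi-node validation, while anything outside that set still raises. A minimal sketch of the error path follows (a hypothetical snippet, not part of this commit); it mocks the CUDA queries the same way the new test does, and it assumes `accelerator='dp'` (DataParallel, the mode called out in the old error message) remains unsupported for `num_nodes > 1`:

from unittest import mock

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Pretend two GPUs are visible so the connector reaches the multi-node check
# on a machine without real hardware (same mocking approach as the test below).
with mock.patch('torch.cuda.is_available', return_value=True), \
        mock.patch('torch.cuda.device_count', return_value=2):
    # DataParallel is not covered by `use_ddp`/`use_ddp2`, so construction
    # should fail with the new MisconfigurationException message.
    with pytest.raises(MisconfigurationException):
        Trainer(accelerator='dp', num_nodes=2, gpus=2)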

tests/accelerators/test_accelerator_connector.py

Lines changed: 26 additions & 0 deletions

@@ -28,10 +28,13 @@
     DDPPlugin,
     DDPShardedPlugin,
     DDPSpawnPlugin,
+    DDPSpawnShardedPlugin,
+    DeepSpeedPlugin,
     PrecisionPlugin,
     SingleDevicePlugin,
 )
 from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
+from pytorch_lightning.utilities import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
@@ -415,3 +418,26 @@ def test_plugin_accelerator_choice(accelerator, plugin):
 
     trainer = Trainer(plugins=plugin, num_processes=2)
     assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
+
+
+@pytest.mark.parametrize(["accelerator", "plugin"], [
+    ('ddp', DDPPlugin),
+    ('ddp_spawn', DDPSpawnPlugin),
+    ('ddp_sharded', DDPShardedPlugin),
+    ('ddp_sharded_spawn', DDPSpawnShardedPlugin),
+    pytest.param(
+        'deepspeed',
+        DeepSpeedPlugin,
+        marks=pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
+    ),
+])
+@mock.patch('torch.cuda.is_available', return_value=True)
+@mock.patch('torch.cuda.device_count', return_value=2)
+def test_accelerator_choice_multi_node_gpu(mock_is_available, mock_device_count, accelerator, plugin, tmpdir):
+    trainer = Trainer(
+        accelerator=accelerator,
+        default_root_dir=tmpdir,
+        num_nodes=2,
+        gpus=2,
+    )
+    assert isinstance(trainer.training_type_plugin, plugin)
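
To exercise just these new cases locally, pytest can be driven from Python (a hypothetical invocation, not part of the commit; the file path and test name come from the diff above, and the DeepSpeed case skips itself when `_DEEPSPEED_AVAILABLE` is false):

import pytest

# Select only the new multi-node accelerator-choice test; -v prints one line
# per parametrized case (ddp, ddp_spawn, ddp_sharded, ddp_sharded_spawn, deepspeed).
exit_code = pytest.main([
    "tests/accelerators/test_accelerator_connector.py",
    "-k", "test_accelerator_choice_multi_node_gpu",
    "-v",
])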
