
Commit 0455231

SeanNaren authored and carmocca committed
[Hot Fix] Give priority to plugins to set distributed mode, and then accelerator (Lightning-AI#6089)
* Give priority to plugins to set distributed mode, and then accelerator
* Add CHANGELOG.md
* Update CHANGELOG.md
* Remove very scary line
* Ensure we set cluster environment after slurm configured if necessary
* Simplify the fix with a reset

Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 7c323ba commit 0455231
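
For context, the user-visible behaviour this commit targets can be sketched from the Trainer arguments involved. The snippet below mirrors the test added in this commit (the `num_processes=2` setting comes from that test); it is an illustration of the intended precedence, not part of the diff:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPShardedPlugin

# A training-type plugin passed via `plugins` now takes priority over the
# distributed mode implied by `accelerator`.
trainer = Trainer(accelerator='ddp_spawn', plugins='ddp_sharded', num_processes=2)
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)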

File tree

3 files changed: +29 -2 lines


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
 
 
+- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
+
+
 ## [1.2.0] - 2021-02-18
 
 ### Added

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 3 additions & 1 deletion
@@ -163,6 +163,9 @@ def handle_given_plugins(
 
         for plug in plugins:
             if isinstance(plug, str):
+                # Reset the distributed type as the user has overridden training type
+                # via the plugins argument
+                self._distrib_type = None
                 self.set_distributed_mode(plug)
 
             elif isinstance(plug, TrainingTypePlugin):
@@ -196,7 +199,6 @@ def handle_given_plugins(
                 )
 
         self._training_type_plugin = training_type
-        self._training_type_plugin = self.training_type_plugin
         self._precision_plugin = precision
         self._cluster_environment = cluster_environment or self.select_cluster_environment()
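
Read on its own, the fix is just the three-line reset above. The resulting precedence can be summarised in a minimal, self-contained sketch (illustrative only; `ConnectorSketch` and its simplified `set_distributed_mode` are stand-ins, not the real `AcceleratorConnector`):

class ConnectorSketch:
    """Illustrative stand-in for AcceleratorConnector's precedence handling."""

    def __init__(self, accelerator=None, plugins=None):
        self._distrib_type = None
        # The accelerator flag sets an initial distributed mode...
        if accelerator is not None:
            self.set_distributed_mode(accelerator)
        # ...but a training-type plugin given as a string resets and re-derives
        # it, so the plugin wins when both are provided.
        for plug in (plugins or []):
            if isinstance(plug, str):
                self._distrib_type = None
                self.set_distributed_mode(plug)

    def set_distributed_mode(self, name):
        # Simplified: the real method maps the string to a DistributedType.
        self._distrib_type = name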

tests/accelerators/test_accelerator_connector.py

Lines changed: 23 additions & 1 deletion
@@ -23,7 +23,14 @@
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
 from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.plugins import DDP2Plugin, DDPPlugin, DDPSpawnPlugin, PrecisionPlugin, SingleDevicePlugin
+from pytorch_lightning.plugins import (
+    DDP2Plugin,
+    DDPPlugin,
+    DDPShardedPlugin,
+    DDPSpawnPlugin,
+    PrecisionPlugin,
+    SingleDevicePlugin,
+)
 from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
 from tests.helpers.boring_model import BoringModel
 
@@ -378,3 +385,18 @@ def on_fit_start(self, trainer, pl_module):
 
     with pytest.raises(SystemExit):
         trainer.fit(model)
+
+
+@pytest.mark.parametrize(
+    ["accelerator", "plugin"],
+    [('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],
+)
+def test_plugin_accelerator_choice(accelerator, plugin):
+    """
+    Ensure that when a plugin and accelerator is passed in, that the plugin takes precedent.
+    """
+    trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2)
+    assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
+
+    trainer = Trainer(plugins=plugin, num_processes=2)
+    assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
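
Assuming a development checkout with the test requirements installed, the new parametrised test can be run in isolation with something like:

python -m pytest tests/accelerators/test_accelerator_connector.py -k test_plugin_accelerator_choice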
