Commit 97a81c3

SeanNaren and carmocca authored
[Hot Fix] Give priority to plugins to set distributed mode, and then accelerator (#6089)
* Give priority to plugins to set distributed mode, and then accelerator
* Add CHANGELOG.md
* Update CHANGELOG.md
* Remove very scary line
* Ensure we set cluster environment after slurm configured if necessary
* Simplify the fix with a reset

Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 3bdc067 commit 97a81c3
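
In user-facing terms, this change means that a training type passed through `plugins` now decides the distributed mode even when an `accelerator` string is also given. A minimal sketch of the resulting behavior, mirroring the test added in this commit (the `num_processes=2` value is only illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPShardedPlugin

# The 'ddp_sharded' plugin string takes priority over the 'ddp_spawn'
# accelerator hint, so the resolved training type is DDPShardedPlugin.
trainer = Trainer(accelerator='ddp_spawn', plugins='ddp_sharded', num_processes=2)
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
```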

File tree: 3 files changed (+28 / -2 lines)


CHANGELOG.md (2 additions, 0 deletions)

@@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)
 
 
+- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
+
 
 ## [1.2.0] - 2021-02-18
 

pytorch_lightning/trainer/connectors/accelerator_connector.py (3 additions, 1 deletion)

@@ -163,6 +163,9 @@ def handle_given_plugins(
 
         for plug in plugins:
             if isinstance(plug, str):
+                # Reset the distributed type as the user has overridden training type
+                # via the plugins argument
+                self._distrib_type = None
                 self.set_distributed_mode(plug)
 
             elif isinstance(plug, TrainingTypePlugin):
@@ -196,7 +199,6 @@ def handle_given_plugins(
                 )
 
         self._training_type_plugin = training_type
-        self._training_type_plugin = self.training_type_plugin
         self._precision_plugin = precision
         self._cluster_environment = cluster_environment or self.select_cluster_environment()
 
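The substance of the fix is the three-line reset above: without it, a `_distrib_type` already derived from the `accelerator` argument could survive into `set_distributed_mode(plug)` and override the plugin's choice. Below is a simplified, self-contained sketch of that priority rule; the class, method bodies, and string handling are illustrative stand-ins rather than the actual Lightning connector:

```python
from typing import List, Optional, Union


class ConnectorSketch:
    """Toy model of the plugin-over-accelerator priority rule in this commit."""

    def __init__(self, accelerator: Optional[str] = None) -> None:
        self._distrib_type: Optional[str] = None
        if accelerator is not None:
            # The accelerator flag picks a distributed mode first ...
            self.set_distributed_mode(accelerator)

    def set_distributed_mode(self, mode: str) -> None:
        # Stand-in for the real mapping of strings such as 'ddp_spawn' or
        # 'ddp_sharded' to a distributed type; only fills an unset value.
        if self._distrib_type is None:
            self._distrib_type = mode

    def handle_given_plugins(self, plugins: List[Union[str, object]]) -> None:
        for plug in plugins:
            if isinstance(plug, str):
                # ... but a plugin string resets it, so the plugin wins
                # (the behavior this commit enforces).
                self._distrib_type = None
                self.set_distributed_mode(plug)


conn = ConnectorSketch(accelerator='ddp_spawn')
conn.handle_given_plugins(['ddp_sharded'])
assert conn._distrib_type == 'ddp_sharded'
```

In the real connector, the resolved type then feeds the selection of the training type plugin and cluster environment, as seen in the second hunk above.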

tests/accelerators/test_accelerator_connector.py (23 additions, 1 deletion)

@@ -23,7 +23,14 @@
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
 from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.plugins import DDP2Plugin, DDPPlugin, DDPSpawnPlugin, PrecisionPlugin, SingleDevicePlugin
+from pytorch_lightning.plugins import (
+    DDP2Plugin,
+    DDPPlugin,
+    DDPShardedPlugin,
+    DDPSpawnPlugin,
+    PrecisionPlugin,
+    SingleDevicePlugin,
+)
 from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
 from tests.helpers.boring_model import BoringModel
 
@@ -378,3 +385,18 @@ def on_fit_start(self, trainer, pl_module):
 
     with pytest.raises(SystemExit):
         trainer.fit(model)
+
+
+@pytest.mark.parametrize(
+    ["accelerator", "plugin"],
+    [('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],
+)
+def test_plugin_accelerator_choice(accelerator, plugin):
+    """
+    Ensure that when a plugin and an accelerator are passed in, the plugin takes precedence.
+    """
+    trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2)
+    assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
+
+    trainer = Trainer(plugins=plugin, num_processes=2)
+    assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
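
The parametrized test exercises both the plugin-plus-accelerator and the plugin-only case, and it only constructs the `Trainer` to inspect the resolved training type plugin (it never calls `fit`). Assuming a development checkout with pytest available, it can be run on its own with `pytest tests/accelerators/test_accelerator_connector.py -k test_plugin_accelerator_choice`.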
