Skip to content

Commit ebabe56

Browse files
ifsheldon, tchaton, awaelchli, and carmocca
authored
Ensure accelerator is valid if running interactively (#5970)
Co-authored-by: chaton <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 863a70c commit ebabe56

File tree

5 files changed

+48
-5
lines changed

5 files changed

+48
-5
lines changed

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,9 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
517517
rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.')
518518
self._distrib_type = None
519519

520+
# finished configuring self._distrib_type, check ipython environment
521+
self.check_interactive_compatibility()
522+
520523
# for DDP overwrite nb processes by requested GPUs
521524
if (
522525
self._device_type == DeviceType.GPU
@@ -558,6 +561,19 @@ def _set_horovod_backend(self):
558561
else:
559562
self.num_processes = hvd.local_size()
560563

564+
def check_interactive_compatibility(self):
565+
"""
566+
Raises a `MisconfigurationException` if the accelerator and/or plugin
567+
is not compatible with an interactive environment
568+
"""
569+
from pytorch_lightning.utilities import _IS_INTERACTIVE
570+
if _IS_INTERACTIVE and self._distrib_type is not None and not self._distrib_type.is_interactive_compatible():
571+
raise MisconfigurationException(
572+
f"Selected distributed backend {self._distrib_type} is not compatible with an interactive"
573+
" environment. Run your code as a script, or choose one of the compatible backends:"
574+
f" {', '.join(DistributedType.interactive_compatible_types())}"
575+
)
576+
561577
def check_horovod(self):
562578
"""Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
563579
if not _HOROVOD_AVAILABLE:

pytorch_lightning/utilities/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""General utilities"""
15-
1615
import numpy
1716

1817
from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401
@@ -33,6 +32,7 @@
3332
_HOROVOD_AVAILABLE,
3433
_HYDRA_AVAILABLE,
3534
_HYDRA_EXPERIMENTAL_AVAILABLE,
35+
_IS_INTERACTIVE,
3636
_module_available,
3737
_NATIVE_AMP_AVAILABLE,
3838
_OMEGACONF_AVAILABLE,

pytorch_lightning/utilities/enums.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
# limitations under the License.
1414
"""Enumerated utilities"""
1515
from enum import Enum
16-
from typing import Union
16+
from typing import List, Optional, Union
1717

1818

1919
class LightningEnum(str, Enum):
2020
""" Type of any enumerator with allowed comparison to string invariant to cases. """
2121

2222
@classmethod
23-
def from_str(cls, value: str) -> 'LightningEnum':
23+
def from_str(cls, value: str) -> Optional['LightningEnum']:
2424
statuses = [status for status in dir(cls) if not status.startswith('_')]
2525
for st in statuses:
2626
if st.lower() == value.lower():
@@ -31,7 +31,7 @@ def __eq__(self, other: Union[str, Enum]) -> bool:
3131
other = other.value if isinstance(other, Enum) else str(other)
3232
return self.value.lower() == other.lower()
3333

34-
def __hash__(self):
34+
def __hash__(self) -> int:
3535
# re-enable hashtable so it can be used as a dict key or in a set
3636
# example: set(LightningEnum)
3737
return hash(self.name)
@@ -58,6 +58,16 @@ class DistributedType(LightningEnum):
5858
>>> DistributedType.DDP2 in ('ddp2', )
5959
True
6060
"""
61+
62+
@staticmethod
63+
def interactive_compatible_types() -> List['DistributedType']:
64+
"""Returns a list containing interactive compatible DistributeTypes"""
65+
return [DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN]
66+
67+
def is_interactive_compatible(self) -> bool:
68+
"""Returns whether self is interactive compatible"""
69+
return self in DistributedType.interactive_compatible_types()
70+
6171
DP = 'dp'
6272
DDP = 'ddp'
6373
DDP2 = 'ddp2'

pytorch_lightning/utilities/imports.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""General utilities"""
1515
import operator
1616
import platform
17+
import sys
1718
from distutils.version import LooseVersion
1819
from importlib.util import find_spec
1920

@@ -49,10 +50,11 @@ def _compare_version(package: str, op, version) -> bool:
4950

5051

5152
_IS_WINDOWS = platform.system() == "Windows"
53+
_IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765
5254
_TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0")
5355
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
5456
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
55-
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none'])
57+
5658
_APEX_AVAILABLE = _module_available("apex.amp")
5759
_BOLTS_AVAILABLE = _module_available('pl_bolts')
5860
_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
@@ -65,6 +67,7 @@ def _compare_version(package: str, op, version) -> bool:
6567
_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
6668
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
6769
_RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc')
70+
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none'])
6871
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
6972
_TORCHVISION_AVAILABLE = _module_available('torchvision')
7073
_XLA_AVAILABLE = _module_available("torch_xla")

tests/accelerators/test_accelerator_connector.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
SingleDevicePlugin,
3333
)
3434
from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
35+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3536
from tests.helpers.boring_model import BoringModel
3637

3738

@@ -387,6 +388,19 @@ def on_fit_start(self, trainer, pl_module):
387388
trainer.fit(model)
388389

389390

391+
@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True)
392+
@mock.patch('torch.cuda.device_count', return_value=2)
393+
def test_ipython_incompatible_backend_error(*_):
394+
with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"):
395+
Trainer(accelerator="ddp", gpus=2)
396+
397+
with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"):
398+
Trainer(accelerator="ddp_cpu", num_processes=2)
399+
400+
with pytest.raises(MisconfigurationException, match="backend ddp2 is not compatible"):
401+
Trainer(accelerator="ddp2", gpus=2)
402+
403+
390404
@pytest.mark.parametrize(
391405
["accelerator", "plugin"],
392406
[('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],

0 commit comments

Comments (0)