Skip to content

Commit ce6b96e

Browse files
committed
Merge branch 'master' into refactor/running_stage
2 parents cbdf2a8 + 97a81c3 commit ce6b96e

File tree

15 files changed

+182
-44
lines changed

15 files changed

+182
-44
lines changed

CHANGELOG.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,31 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

77

8+
## [UnReleased] - 2021-MM-DD
9+
10+
### Added
11+
12+
13+
### Changed
14+
15+
16+
### Deprecated
17+
18+
19+
### Removed
20+
21+
22+
### Fixed
23+
24+
- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
25+
26+
27+
- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)
28+
29+
30+
- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
31+
32+
833
## [1.2.0] - 2021-02-18
934

1035
### Added

azure-pipelines.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,9 @@ jobs:
6666
pip list
6767
displayName: 'Install dependencies'
6868
69-
- script: |
69+
- bash: |
7070
python tests/collect_env_details.py
71+
python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu >= 2, f'GPU: {mgpu}'"
7172
displayName: 'Env details'
7273
7374
- bash: |
@@ -76,7 +77,7 @@ jobs:
7677
ls -l legacy/checkpoints/
7778
displayName: 'Get legacy checkpoints'
7879
79-
- script: |
80+
- bash: |
8081
python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50
8182
displayName: 'Testing: standard'
8283
@@ -90,11 +91,11 @@ jobs:
9091
codecov --token=$(CODECOV_TOKEN) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure
9192
displayName: 'Statistics'
9293
93-
- script: |
94+
- bash: |
9495
python -m pytest benchmarks pl_examples -v --maxfail=2 --durations=0
9596
displayName: 'Testing: extended'
9697
97-
- script: |
98+
- bash: |
9899
python setup.py install --user --quiet
99100
bash pl_examples/run_ddp-example.sh
100101
pip uninstall -y pytorch-lightning

pytorch_lightning/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66

77
_this_year = time.strftime("%Y")
8-
__version__ = '1.2.0'
8+
__version__ = '1.3.0dev'
99
__author__ = 'William Falcon et al.'
1010
__author_email__ = '[email protected]'
1111
__license__ = 'Apache-2.0'

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,5 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
9191
@contextmanager
9292
def train_step_context(self) -> Generator[autocast, None, None]:
9393
"""Enable autocast context"""
94-
yield torch.cuda.amp.autocast()
94+
with torch.cuda.amp.autocast():
95+
yield

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,10 +278,22 @@ def model_to_device(self):
278278
torch.cuda.set_device(self.root_device)
279279
self.model.to(self.root_device)
280280

281-
def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
282-
if isinstance(output, torch.Tensor):
283-
output = sync_ddp_if_available(output, group, reduce_op)
284-
return output
281+
def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
282+
"""
283+
Reduces a tensor from several distributed processes to one aggregated tensor.
284+
285+
Args:
286+
tensor: the tensor to sync and reduce
287+
group: the process group to gather results from. Defaults to all processes (world)
288+
reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
289+
Can also be a string 'sum' to calculate the sum during reduction.
290+
291+
Return:
292+
reduced value, except when the input was not a tensor the output remains is unchanged
293+
"""
294+
if isinstance(tensor, torch.Tensor):
295+
tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean"))
296+
return tensor
285297

286298
def training_step(self, *args, **kwargs):
287299
return self.model(*args, **kwargs)

pytorch_lightning/plugins/training_type/ddp2.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,26 @@ def setup(self, model):
2525
self.task_idx = self.cluster_environment.local_rank()
2626
# the difference to DDP is that we don't call children processes here
2727

28-
def reduce(self, output, *args, **kwargs):
29-
if isinstance(output, Result):
30-
output.dp_reduce()
28+
def reduce(self, tensor, *args, **kwargs):
29+
"""
30+
Reduces a tensor from all processes to one aggregated tensor.
31+
In DDP2, the reduction here is only across local devices within the node.
3132
32-
elif isinstance(output, torch.Tensor):
33-
output = output.mean()
33+
Args:
34+
tensor: the tensor to sync and reduce
35+
*args: ignored for DDP2
36+
**kwargs: ignored for DDP2
3437
35-
return output
38+
Return:
39+
reduced value, except when the input was not a tensor the output remains is unchanged
40+
"""
41+
if isinstance(tensor, Result):
42+
tensor.dp_reduce()
43+
44+
elif isinstance(tensor, torch.Tensor):
45+
tensor = tensor.mean()
46+
47+
return tensor
3648

3749
@property
3850
def root_device(self):

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,10 +256,22 @@ def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, opti
256256
if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
257257
prepare_for_backward(self.model, closure_loss)
258258

259-
def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
260-
if isinstance(output, torch.Tensor):
261-
output = sync_ddp_if_available(output, group, reduce_op)
262-
return output
259+
def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
260+
"""
261+
Reduces a tensor from several distributed processes to one aggregated tensor.
262+
263+
Args:
264+
tensor: the tensor to sync and reduce
265+
group: the process group to gather results from. Defaults to all processes (world)
266+
reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
267+
Can also be a string 'sum' to calculate the sum during reduction.
268+
269+
Return:
270+
reduced value, except when the input was not a tensor the output remains is unchanged
271+
"""
272+
if isinstance(tensor, torch.Tensor):
273+
tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean"))
274+
return tensor
263275

264276
def training_step(self, *args, **kwargs):
265277
return self.model(*args, **kwargs)

pytorch_lightning/plugins/training_type/dp.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,25 @@ def setup(self, model):
3131
model.to(self.root_device)
3232
self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)
3333

34-
def reduce(self, output, *args, **kwargs):
35-
if isinstance(output, Result):
36-
output.dp_reduce()
34+
def reduce(self, tensor, *args, **kwargs):
35+
"""
36+
Reduces a tensor from all parallel processes to one aggregated tensor.
3737
38-
elif isinstance(output, torch.Tensor):
39-
output = output.mean()
38+
Args:
39+
tensor: the tensor to sync and reduce
40+
*args: ignored for DP
41+
**kwargs: ignored for DP
4042
41-
return output
43+
Return:
44+
reduced value, except when the input was not a tensor the output remains is unchanged
45+
"""
46+
if isinstance(tensor, Result):
47+
tensor.dp_reduce()
48+
49+
elif isinstance(tensor, torch.Tensor):
50+
tensor = tensor.mean()
51+
52+
return tensor
4253

4354
@property
4455
def root_device(self):

pytorch_lightning/plugins/training_type/horovod.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,23 +127,35 @@ def model_to_device(self):
127127
torch.cuda.set_device(self.root_device)
128128
self.model.to(self.root_device)
129129

130-
def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
130+
def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
131+
"""
132+
Reduces a tensor from several distributed processes to one aggregated tensor.
133+
134+
Args:
135+
tensor: the tensor to sync and reduce
136+
group: the process group to gather results from. Defaults to all processes (world)
137+
reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
138+
Can also be a string 'sum' to calculate the sum during reduction.
139+
140+
Return:
141+
reduced value, except when the input was not a tensor the output remains is unchanged
142+
"""
131143
if group is not None:
132144
raise ValueError(
133145
"Horovod does not support allreduce using a subcommunicator at this time. "
134146
"Unset `group`."
135147
)
136148

137-
if reduce_op is None or reduce_op == "sum":
138-
reduce_op = hvd.Sum
139-
elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
149+
if reduce_op in (None, "avg", "mean"):
140150
reduce_op = hvd.Average
151+
elif reduce_op == "sum":
152+
reduce_op = hvd.Sum
141153
else:
142154
raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")
143155

144156
# sync all processes before reduction
145157
hvd.join()
146-
return hvd.allreduce(output, op=reduce_op)
158+
return hvd.allreduce(tensor, op=reduce_op)
147159

148160
def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None):
149161
if group is not None:

pytorch_lightning/plugins/training_type/single_device.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,20 @@ def on_tpu(self) -> bool:
1919
def on_gpu(self) -> bool:
2020
return self.device.type == "cuda" and torch.cuda.is_available()
2121

22-
def reduce(self, output: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
23-
return output
22+
def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
23+
"""
24+
Reduces a tensor from several distributed processes to one aggregated tensor.
25+
As this plugin only operates with a single device, the reduction is simply the identity.
26+
27+
Args:
28+
tensor: the tensor to sync and reduce
29+
*args: ignored
30+
**kwargs: ignored
31+
32+
Return:
33+
the unmodified input as reduction is not needed for single process operation
34+
"""
35+
return tensor
2436

2537
@property
2638
def root_device(self) -> torch.device:

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,15 @@ def is_global_zero(self) -> bool:
5555
"""Whether the current process is the rank zero process not only on the local node, but for all nodes."""
5656

5757
@abstractmethod
58-
def reduce(self, output: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
59-
"""Reduces the given output (e.g. across GPUs/Processes)"""
58+
def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
59+
"""
60+
Reduces the given tensor (e.g. across GPUs/processes).
61+
62+
Args:
63+
tensor: the tensor to sync and reduce
64+
*args: plugin-specific positional arguments
65+
**kwargs: plugin-specific keyword arguments
66+
"""
6067

6168
@abstractmethod
6269
def barrier(self, name: Optional[str] = None) -> None:

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ def handle_given_plugins(
163163

164164
for plug in plugins:
165165
if isinstance(plug, str):
166+
# Reset the distributed type as the user has overridden training type
167+
# via the plugins argument
168+
self._distrib_type = None
166169
self.set_distributed_mode(plug)
167170

168171
elif isinstance(plug, TrainingTypePlugin):
@@ -196,7 +199,6 @@ def handle_given_plugins(
196199
)
197200

198201
self._training_type_plugin = training_type
199-
self._training_type_plugin = self.training_type_plugin
200202
self._precision_plugin = precision
201203
self._cluster_environment = cluster_environment or self.select_cluster_environment()
202204

tests/accelerators/test_accelerator_connector.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,14 @@
2323
from pytorch_lightning.accelerators.cpu import CPUAccelerator
2424
from pytorch_lightning.accelerators.gpu import GPUAccelerator
2525
from pytorch_lightning.callbacks import Callback
26-
from pytorch_lightning.plugins import DDP2Plugin, DDPPlugin, DDPSpawnPlugin, PrecisionPlugin, SingleDevicePlugin
26+
from pytorch_lightning.plugins import (
27+
DDP2Plugin,
28+
DDPPlugin,
29+
DDPShardedPlugin,
30+
DDPSpawnPlugin,
31+
PrecisionPlugin,
32+
SingleDevicePlugin,
33+
)
2734
from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
2835
from tests.helpers.boring_model import BoringModel
2936

@@ -378,3 +385,18 @@ def on_fit_start(self, trainer, pl_module):
378385

379386
with pytest.raises(SystemExit):
380387
trainer.fit(model)
388+
389+
390+
@pytest.mark.parametrize(
391+
["accelerator", "plugin"],
392+
[('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],
393+
)
394+
def test_plugin_accelerator_choice(accelerator, plugin):
395+
"""
396+
Ensure that when a plugin and accelerator is passed in, that the plugin takes precedent.
397+
"""
398+
trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2)
399+
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
400+
401+
trainer = Trainer(plugins=plugin, num_processes=2)
402+
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)

tests/checkpointing/test_legacy_checkpoints.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"1.1.6",
5353
"1.1.7",
5454
"1.1.8",
55+
"1.2.0",
5556
]
5657
)
5758
def test_resume_legacy_checkpoints(tmpdir, pl_version):

0 commit comments

Comments
 (0)