
Commit 73ef543

tchaton authored and lexierule committed
[bugfix] Perform reduction for dict in training_step and DP (#6324)
* fix
* update
* update
* add changelog
* Update CHANGELOG.md

Co-authored-by: Carlos Mocholí <[email protected]>

* Update tests/accelerators/test_dp.py

Co-authored-by: Carlos Mocholí <[email protected]>

* update changelog

Co-authored-by: Carlos Mocholí <[email protected]>

(cherry picked from commit 248a8e8)
1 parent a434725 commit 73ef543
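For context on the bug this commit fixes: under DP, `DataParallel` gathers per-GPU outputs, so a `training_step` that returns a plain dict ends up holding multi-element tensors that the plugin previously never averaged (only bare tensors and `Result` objects were reduced). Below is a minimal sketch of the affected pattern; the module, layer sizes, and dict keys are hypothetical, for illustration only:

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Hypothetical module; only training_step matters here."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).mean()
        # Returning a dict: under DP, each value is gathered across
        # GPUs and must be reduced back to a scalar by the plugin.
        return {"loss": loss, "some_metric": loss.detach()}
```

With something like `Trainer(gpus=2, accelerator="dp")` (the 1.2-era flag), each value in the returned dict arrives in `reduce` as a tensor with one element per device.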


2 files changed (+26 −6 lines)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -169,6 +169,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297))
 
 
+- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))
+
+
 ## [1.2.1] - 2021-02-23
 
 ### Fixed
```

pytorch_lightning/plugins/training_type/dp.py

Lines changed: 23 additions & 6 deletions
```diff
@@ -19,6 +19,7 @@
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.overrides.data_parallel import LightningParallelModule
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 
 class DataParallelPlugin(ParallelPlugin):
@@ -31,14 +32,30 @@ def setup(self, model):
         model.to(self.root_device)
         self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)
 
-    def reduce(self, output, *args, **kwargs):
-        if isinstance(output, Result):
-            output.dp_reduce()
+    def reduce(self, tensor, *args, **kwargs):
+        """
+        Reduces a tensor from all parallel processes to one aggregated tensor.
 
-        elif isinstance(output, torch.Tensor):
-            output = output.mean()
+        Args:
+            tensor: the tensor to sync and reduce
+            *args: ignored for DP
+            **kwargs: ignored for DP
 
-        return output
+        Return:
+            the reduced value; when the input is not a tensor, the output remains unchanged
+        """
+        if isinstance(tensor, Result):
+            tensor.dp_reduce()
 
+        else:
+
+            def _reduce(tensor: torch.Tensor):
+                dtype_tensor = tensor.dtype
+                return tensor.float().mean().type(dtype_tensor)
+
+            tensor = apply_to_collection(tensor, torch.Tensor, _reduce)
+
+        return tensor
 
     @property
     def root_device(self):
```
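To see what the new `reduce` path does with a collection, here is a standalone sketch of `_reduce` applied via `apply_to_collection`; the example dict and its values are made up, and it assumes the 1.2-era `pytorch_lightning.utilities.apply_func` import path shown in the diff:

```python
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection


def _reduce(tensor: torch.Tensor) -> torch.Tensor:
    # Take the mean in float32, then cast back, so low-precision
    # dtypes (e.g. fp16) do not lose accuracy during the reduction.
    dtype_tensor = tensor.dtype
    return tensor.float().mean().type(dtype_tensor)


# A dict as gathered by DP from 2 GPUs: one value per device.
output = {"loss": torch.tensor([0.5, 0.7]), "acc": torch.tensor([0.8, 0.9])}
print(apply_to_collection(output, torch.Tensor, _reduce))
# {'loss': tensor(0.6000), 'acc': tensor(0.8500)}
```

The float cast before the mean is the notable design choice: `mean()` is not defined for integer tensors and can lose precision in fp16, so the helper reduces in float32 and restores the original dtype afterwards.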
