
Commit 3bdc067

awaelchli and Borda authored
consistent behavior for reduce method across all Plugins (#6011)
* reduction docs
* docs for abstract base method
* make mean the default
* add preliminary chlog

Co-authored-by: Jirka Borovec <[email protected]>
1 parent f2660ac commit 3bdc067

File tree

8 files changed (+111 / -29 lines)


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -24,6 +24,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
 
 
+- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
+
+
+
 ## [1.2.0] - 2021-02-18
 
 ### Added
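
To make the changelog entry concrete, here is a minimal sketch (not part of the commit) contrasting a sum-reduction with the new default mean-reduction, using values from two hypothetical processes:

import torch

# Minimal sketch: what 'sum' vs the new default 'mean' reduction returns
# for values coming from two hypothetical processes.
per_process = [torch.tensor([2.0, 4.0]), torch.tensor([6.0, 8.0])]
stacked = torch.stack(per_process)
print(stacked.sum(dim=0))   # tensor([ 8., 12.])  -> reduce_op="sum"
print(stacked.mean(dim=0))  # tensor([4., 6.])    -> reduce_op="mean" (new default)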

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 16 additions & 4 deletions
@@ -278,10 +278,22 @@ def model_to_device(self):
         torch.cuda.set_device(self.root_device)
         self.model.to(self.root_device)
 
-    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
-        if isinstance(output, torch.Tensor):
-            output = sync_ddp_if_available(output, group, reduce_op)
-        return output
+    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
+        """
+        Reduces a tensor from several distributed processes to one aggregated tensor.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            group: the process group to gather results from. Defaults to all processes (world)
+            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
+                Can also be a string 'sum' to calculate the sum during reduction.
+
+        Return:
+            reduced value, except when the input was not a tensor the output remains unchanged
+        """
+        if isinstance(tensor, torch.Tensor):
+            tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean"))
+        return tensor
 
     def training_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
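
For context, a rough sketch of the kind of all-reduce the DDP plugin delegates to `sync_ddp_if_available` when the op is 'mean'. The single-process gloo group, address, and port are only there so the snippet runs without a launcher; they are assumptions, not part of this commit:

import os
import torch
import torch.distributed as dist

# Single-process "world" so the example is runnable on one machine.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

loss = torch.tensor(3.0)
dist.all_reduce(loss, op=dist.ReduceOp.SUM)  # sum the tensor across all ranks
loss = loss / dist.get_world_size()          # divide by the world size -> 'mean'
print(loss)                                  # tensor(3.) with a single process

dist.destroy_process_group()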

pytorch_lightning/plugins/training_type/ddp2.py

Lines changed: 18 additions & 6 deletions
@@ -25,14 +25,26 @@ def setup(self, model):
         self.task_idx = self.cluster_environment.local_rank()
         # the difference to DDP is that we don't call children processes here
 
-    def reduce(self, output, *args, **kwargs):
-        if isinstance(output, Result):
-            output.dp_reduce()
+    def reduce(self, tensor, *args, **kwargs):
+        """
+        Reduces a tensor from all processes to one aggregated tensor.
+        In DDP2, the reduction here is only across local devices within the node.
 
-        elif isinstance(output, torch.Tensor):
-            output = output.mean()
+        Args:
+            tensor: the tensor to sync and reduce
+            *args: ignored for DDP2
+            **kwargs: ignored for DDP2
 
-        return output
+        Return:
+            reduced value, except when the input was not a tensor the output remains unchanged
+        """
+        if isinstance(tensor, Result):
+            tensor.dp_reduce()
+
+        elif isinstance(tensor, torch.Tensor):
+            tensor = tensor.mean()
+
+        return tensor
 
     @property
     def root_device(self):
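
A small sketch of the plain-tensor branch in the DDP2 `reduce` above, using hypothetical per-GPU losses from one node:

import torch

# DDP2 averages the per-device outputs of one node with .mean();
# the four values below are hypothetical per-GPU losses.
per_device_losses = torch.tensor([0.9, 1.1, 1.0, 1.2])
print(per_device_losses.mean())  # tensor(1.0500)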

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 16 additions & 4 deletions
@@ -256,10 +256,22 @@ def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, opti
         if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
             prepare_for_backward(self.model, closure_loss)
 
-    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
-        if isinstance(output, torch.Tensor):
-            output = sync_ddp_if_available(output, group, reduce_op)
-        return output
+    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
+        """
+        Reduces a tensor from several distributed processes to one aggregated tensor.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            group: the process group to gather results from. Defaults to all processes (world)
+            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
+                Can also be a string 'sum' to calculate the sum during reduction.
+
+        Return:
+            reduced value, except when the input was not a tensor the output remains unchanged
+        """
+        if isinstance(tensor, torch.Tensor):
+            tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean"))
+        return tensor
 
     def training_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
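
The `reduce_op=(reduce_op or "mean")` guard used in both DDP plugins can be illustrated in isolation; the snippet below is only a sketch of that fallback:

# Passing reduce_op=None explicitly still results in a mean reduction.
for reduce_op in ("mean", "sum", None):
    effective = reduce_op or "mean"
    print(repr(reduce_op), "->", effective)  # None -> 'mean'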

pytorch_lightning/plugins/training_type/dp.py

Lines changed: 17 additions & 6 deletions
@@ -31,14 +31,25 @@ def setup(self, model):
         model.to(self.root_device)
         self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)
 
-    def reduce(self, output, *args, **kwargs):
-        if isinstance(output, Result):
-            output.dp_reduce()
+    def reduce(self, tensor, *args, **kwargs):
+        """
+        Reduces a tensor from all parallel processes to one aggregated tensor.
 
-        elif isinstance(output, torch.Tensor):
-            output = output.mean()
+        Args:
+            tensor: the tensor to sync and reduce
+            *args: ignored for DP
+            **kwargs: ignored for DP
 
-        return output
+        Return:
+            reduced value, except when the input was not a tensor the output remains unchanged
+        """
+        if isinstance(tensor, Result):
+            tensor.dp_reduce()
+
+        elif isinstance(tensor, torch.Tensor):
+            tensor = tensor.mean()
+
+        return tensor
 
     @property
     def root_device(self):
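
A sketch of the fallback dispatch in the DP `reduce` above (the `Result` branch is omitted, since `Result.dp_reduce` is internal to Lightning): tensors are averaged, anything else passes through unchanged.

import torch

def reduce_like_dp(value):
    # Average plain tensors; leave any non-tensor value untouched.
    if isinstance(value, torch.Tensor):
        value = value.mean()
    return value

print(reduce_like_dp(torch.tensor([0.4, 0.6])))  # tensor(0.5000)
print(reduce_like_dp(42))                        # 42, unchanged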

pytorch_lightning/plugins/training_type/horovod.py

Lines changed: 17 additions & 5 deletions
@@ -127,23 +127,35 @@ def model_to_device(self):
         torch.cuda.set_device(self.root_device)
         self.model.to(self.root_device)
 
-    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
+    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
+        """
+        Reduces a tensor from several distributed processes to one aggregated tensor.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            group: the process group to gather results from. Defaults to all processes (world)
+            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
+                Can also be a string 'sum' to calculate the sum during reduction.
+
+        Return:
+            reduced value, except when the input was not a tensor the output remains unchanged
+        """
         if group is not None:
             raise ValueError(
                 "Horovod does not support allreduce using a subcommunicator at this time. "
                 "Unset `group`."
             )
 
-        if reduce_op is None or reduce_op == "sum":
-            reduce_op = hvd.Sum
-        elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
+        if reduce_op in (None, "avg", "mean"):
            reduce_op = hvd.Average
+        elif reduce_op == "sum":
+            reduce_op = hvd.Sum
         else:
            raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")
 
        # sync all processes before reduction
        hvd.join()
-        return hvd.allreduce(output, op=reduce_op)
+        return hvd.allreduce(tensor, op=reduce_op)
 
     def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None):
         if group is not None:
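
The behavioral change for Horovod is that an unspecified `reduce_op` now maps to an average rather than a sum. Below is a standalone sketch of the new mapping, with string stand-ins for `hvd.Average` / `hvd.Sum` so it runs without Horovod installed:

def map_reduce_op(reduce_op):
    # Mirrors the new branch order: None/'avg'/'mean' -> average, 'sum' -> sum.
    if reduce_op in (None, "avg", "mean"):
        return "Average"
    elif reduce_op == "sum":
        return "Sum"
    raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")

print(map_reduce_op(None), map_reduce_op("mean"), map_reduce_op("sum"))  # Average Average Sum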

pytorch_lightning/plugins/training_type/single_device.py

Lines changed: 14 additions & 2 deletions
@@ -19,8 +19,20 @@ def on_tpu(self) -> bool:
     def on_gpu(self) -> bool:
         return self.device.type == "cuda" and torch.cuda.is_available()
 
-    def reduce(self, output: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
-        return output
+    def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
+        """
+        Reduces a tensor from several distributed processes to one aggregated tensor.
+        As this plugin only operates with a single device, the reduction is simply the identity.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            *args: ignored
+            **kwargs: ignored
+
+        Return:
+            the unmodified input as reduction is not needed for single process operation
+        """
+        return tensor
 
     @property
     def root_device(self) -> torch.device:

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 9 additions & 2 deletions
@@ -55,8 +55,15 @@ def is_global_zero(self) -> bool:
         """Whether the current process is the rank zero process not only on the local node, but for all nodes."""
 
     @abstractmethod
-    def reduce(self, output: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
-        """Reduces the given output (e.g. across GPUs/Processes)"""
+    def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
+        """
+        Reduces the given tensor (e.g. across GPUs/processes).
+
+        Args:
+            tensor: the tensor to sync and reduce
+            *args: plugin-specific positional arguments
+            **kwargs: plugin-specific keyword arguments
+        """
 
     @abstractmethod
     def barrier(self, name: Optional[str] = None) -> None:
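
As an illustration of the abstract contract (not code from this commit), a hypothetical minimal plugin could implement `reduce` as a mean over tensors with pass-through for everything else:

import torch
from typing import Any, Union

class ToyPlugin:
    # Hypothetical plugin: satisfies the signature above by mean-reducing tensors.
    def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
        if isinstance(tensor, torch.Tensor):
            return tensor.float().mean()
        return tensor

plugin = ToyPlugin()
print(plugin.reduce(torch.tensor([1.0, 3.0])))  # tensor(2.)
print(plugin.reduce("not a tensor"))            # passes through unchanged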
