19
19
from pytorch_lightning .core .step_result import Result
20
20
from pytorch_lightning .overrides .data_parallel import LightningParallelModule
21
21
from pytorch_lightning .plugins .training_type .parallel import ParallelPlugin
22
+ from pytorch_lightning .utilities .apply_func import apply_to_collection
22
23
23
24
24
25
class DataParallelPlugin (ParallelPlugin ):
@@ -31,14 +32,30 @@ def setup(self, model):
31
32
model .to (self .root_device )
32
33
self ._model = DataParallel (LightningParallelModule (model ), self .parallel_devices )
33
34
34
- def reduce (self , output , * args , ** kwargs ):
35
- if isinstance ( output , Result ):
36
- output . dp_reduce ()
35
+ def reduce (self , tensor , * args , ** kwargs ):
36
+ """
37
+ Reduces a tensor from all parallel processes to one aggregated tensor.
37
38
38
- elif isinstance (output , torch .Tensor ):
39
- output = output .mean ()
39
+ Args:
40
+ tensor: the tensor to sync and reduce
41
+ *args: ignored for DP
42
+ **kwargs: ignored for DP
40
43
41
- return output
44
+ Return:
45
+ reduced value, except when the input was not a tensor the output remains is unchanged
46
+ """
47
+ if isinstance (tensor , Result ):
48
+ tensor .dp_reduce ()
49
+
50
+ else :
51
+
52
+ def _reduce (tensor : torch .Tensor ):
53
+ dtype_tensor = tensor .dtype
54
+ return tensor .float ().mean ().type (dtype_tensor )
55
+
56
+ tensor = apply_to_collection (tensor , torch .Tensor , _reduce )
57
+
58
+ return tensor
42
59
43
60
@property
44
61
def root_device (self ):
0 commit comments