
Commit 0544efd

tchaton and carmocca authored

[bug] Update broadcast + reduce decision [ModelCheckpoint] (#6410)
* resolve bug
* update
* update changelog
* update PR
* Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
  Co-authored-by: Carlos Mocholí <[email protected]>
* add todo
* resolve issues
* resolve flake8
* update
* add coverage for reduce
* wip
* restore back to brodbact
* remove test.py
* resolve flake8
* update
* check world size
* resolve test
* update
* use pytorch version when defined
* update on comments
* update on comments
* flake8
* resolve bugs
* Update CHANGELOG.md
  Co-authored-by: Carlos Mocholí <[email protected]>
* update
* update
* update
* update
* remove test
* update
* resolve flake8
* update
* update
* update
* proxy
* update
* update
* resolve typo
* prune
* update parallel
* update

Co-authored-by: Carlos Mocholí <[email protected]>
1 parent dcd9dd8 commit 0544efd

21 files changed: +296 −144 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -128,6 +128,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))


+- Fixed broadcast to use PyTorch `broadcast_object_list` and add `reduce_decision` ([#6410](https://github.com/PyTorchLightning/pytorch-lightning/pull/6410))
+
+
 - Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380))


pytorch_lightning/accelerators/accelerator.py

Lines changed: 1 addition & 2 deletions
@@ -22,7 +22,6 @@
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities.apply_func import move_data_to_device
-from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
 from pytorch_lightning.utilities.enums import AMPType, LightningEnum

 if TYPE_CHECKING:
@@ -405,7 +404,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         Return:
             A tensor of shape (world_size, batch, ...)
         """
-        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+        return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads)

     def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
         """Wraps the dataloader if necessary

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 1 addition & 1 deletion
@@ -172,4 +172,4 @@ def _run_early_stopping_check(self, trainer):
             trainer.should_stop = True

         # stop every ddp process if any world process decides to stop
-        trainer.should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop)
+        trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop)

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 9 additions & 13 deletions
@@ -424,7 +424,7 @@ def _do_save(self, trainer, filepath: str):
         else:
             raise ValueError(".save_function() not set")

-    def check_monitor_top_k(self, current: torch.Tensor) -> bool:
+    def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool:
         if current is None:
             return False

@@ -444,7 +444,12 @@ def check_monitor_top_k(self, current: torch.Tensor) -> bool:
             current = torch.tensor(current)

         monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
-        return monitor_op(current, self.best_k_models[self.kth_best_model_path]).item()
+        should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
+
+        # If using multiple devices, make sure all processes are unanimous on the decision.
+        should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save)
+
+        return should_update_best_and_save

     @classmethod
     def _format_checkpoint_name(
@@ -638,15 +643,7 @@ def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
         epoch = monitor_candidates.get("epoch")
         step = monitor_candidates.get("step")

-        # when `val_loss` is being logged and no ModelCheckpoint is being provided
-        # `val_loss` will be selected for monitor and need to be reduced to
-        # prevent processes divergence
-        # TODO: Move this logic to logger_connector. This also needs to be fixed for any
-        # other monitor logged value which aren't produced from a Metric.
-        if self.monitor == "val_loss":
-            current = trainer.training_type_plugin.reduce(current, reduce_op="mean")
-
-        if self.check_monitor_top_k(current):
+        if self.check_monitor_top_k(trainer, current):
             self._update_best_and_save(current, epoch, step, trainer, monitor_candidates)
         elif self.verbose:
             rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")
@@ -731,5 +728,4 @@ def file_exists(self, filepath: Union[str, Path], trainer) -> bool:
         the internal state to diverge between ranks.
         """
         exists = self._fs.exists(filepath)
-        exists = trainer.training_type_plugin.broadcast(exists)
-        return exists
+        return trainer.training_type_plugin.broadcast(exists)
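For context, a minimal single-process sketch of the comparison that `check_monitor_top_k` performs before the decision is reduced across ranks; the metric values below are made up for illustration:

```python
import torch

# Illustrative stand-ins for `current` and the kth-best monitored score.
mode = "min"
monitor_op = {"min": torch.lt, "max": torch.gt}[mode]
current = torch.tensor(0.20)
kth_best = torch.tensor(0.25)

# Local decision; with multiple devices, reduce_boolean_decision makes it unanimous.
should_update_best_and_save = monitor_op(current, kth_best)
assert bool(should_update_best_and_save)  # 0.20 < 0.25 in "min" mode
```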

pytorch_lightning/distributed/dist.py

Lines changed: 12 additions & 39 deletions
@@ -11,18 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import io
 from typing import Any

-import torch
-from torch import distributed as torch_distrib
-
-from pytorch_lightning.utilities import _GROUP_AVAILABLE
-
-WORLD = None
-if _GROUP_AVAILABLE:
-    from torch.distributed import group
-    WORLD = group.WORLD
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
+from pytorch_lightning.utilities.distributed import group as _group


 class LightningDistributed:
@@ -31,32 +23,13 @@ def __init__(self, rank=None, device=None):
         self.rank = rank
         self.device = device

-    def broadcast(self, obj: Any, group=WORLD):
-        if self.rank == 0:
-            self._emit(obj, group)
-        else:
-            obj = self._receive(group)
-        return obj
-
-    def _broadcast(self, tensor, src=0, group=WORLD):
-        if group is None:
-            return torch_distrib.broadcast(tensor, src=src)
-        return torch_distrib.broadcast(tensor, src=0, group=group)
-
-    def _emit(self, obj: Any, group=WORLD):
-        buffer = io.BytesIO()
-        torch.save(obj, buffer)
-        data = bytearray(buffer.getbuffer())
-        length_tensor = torch.tensor([len(data)]).long().to(self.device)
-        self._broadcast(length_tensor, src=0, group=group)
-        data_tensor = torch.ByteTensor(data).to(self.device)
-        self._broadcast(data_tensor, src=0, group=group)
-
-    def _receive(self, group=WORLD):
-        length_tensor = torch.tensor([0]).long().to(self.device)
-        self._broadcast(length_tensor, src=0, group=group)
-        data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device)
-        self._broadcast(data_tensor, src=0, group=group)
-        buffer = io.BytesIO(data_tensor.cpu().numpy())
-        obj = torch.load(buffer)
-        return obj
+    def broadcast(self, obj: Any, group=_group.WORLD):
+        # always wrap into a list so the list can be broadcasted
+        obj = [obj]
+
+        if self.rank != 0:
+            obj = [None] * len(obj)
+
+        broadcast_object_list(obj, 0, group=group or _group.WORLD)
+
+        return obj[0]
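A rough usage sketch of the list-wrapping contract behind the new `broadcast` (the helper name and the assumption of an already-initialized process group are illustrative, not part of this commit): rank 0 supplies the payload, every other rank passes a placeholder, and all ranks return rank 0's object.

```python
import torch.distributed as dist

from pytorch_lightning.overrides.torch_distributed import broadcast_object_list


def sync_object_from_rank_zero(obj):
    # Assumes torch.distributed is already initialized (e.g. inside a DDP run).
    payload = [obj if dist.get_rank() == 0 else None]
    # broadcast_object_list mutates the list in place, which is why
    # LightningDistributed.broadcast wraps the object into a one-element list.
    broadcast_object_list(payload, 0)
    return payload[0]
```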
pytorch_lightning/overrides/torch_distributed.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+import logging
+import pickle
+
+import torch
+
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7
+
+log = logging.getLogger(__name__)
+
+if torch.distributed.is_available():
+    from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember
+
+# The code underneath is taken from PyTorch ``torch/distributed/distributed_c10d.py``
+# and enables broadcasting for PyTorch 1.6 and lower.
+
+
+# https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160
+def _rank_not_in_group(group):
+    """
+    Helper that checks if the current process's rank is not in a given group.
+    """
+    if group is None:
+        return False
+    return group == GroupMember.NON_GROUP_MEMBER
+
+
+# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164
+def _object_to_tensor(obj):
+    buffer = pickle.dumps(obj)
+    byte_storage = torch.ByteStorage.from_buffer(buffer)  # type: ignore[attr-defined]
+    byte_tensor = torch.ByteTensor(byte_storage)
+    local_size = torch.LongTensor([byte_tensor.numel()])
+    return byte_tensor, local_size
+
+
+# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py
+def _tensor_to_object(tensor, tensor_size):
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    out = pickle.loads(buf)
+    return out
+
+
+# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327
+def _broadcast_object_list(object_list, src=0, group=None):
+    if _rank_not_in_group(group):
+        return
+
+    my_rank = get_rank()
+    # Serialize object_list elements to tensors on src rank.
+    if my_rank == src:
+        tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list])
+        object_sizes_tensor = torch.cat(size_list)
+    else:
+        object_sizes_tensor = torch.LongTensor(len(object_list))
+
+    group_backend = get_backend(group)
+    is_nccl_backend = group_backend == Backend.NCCL
+    current_device = torch.device("cpu")
+    if is_nccl_backend:
+        # See note about using torch.cuda.current_device() here in docstring.
+        # We cannot simply use my_rank since rank == device is not necessarily
+        # true.
+        current_device = torch.device('cuda', torch.cuda.current_device())
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
+
+    # Broadcast object sizes
+    broadcast(object_sizes_tensor, src=src, group=group)
+
+    # Concatenate and broadcast serialized object tensors
+    if my_rank == src:
+        object_tensor = torch.cat(tensor_list)
+    else:
+        object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item())
+
+    if is_nccl_backend:
+        object_tensor = object_tensor.to(current_device)
+
+    broadcast(object_tensor, src=src, group=group)
+
+    # Deserialize objects using their stored sizes.
+    offset = 0
+    if my_rank != src:
+        for i, obj_size in enumerate(object_sizes_tensor):
+            obj_view = object_tensor[offset:offset + obj_size]
+            obj_view = obj_view.type(torch.ByteTensor)  # type: ignore[call-overload]
+            offset += obj_size
+            object_list[i] = _tensor_to_object(obj_view, obj_size)
+
+
+if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available():
+    from torch.distributed.distributed_c10d import broadcast_object_list
+else:
+    broadcast_object_list = _broadcast_object_list
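As a sanity check of the serialization path the backport relies on, the pickle round-trip helpers can be exercised on a single process with no process group; this sketch assumes the private helpers are importable from the new module exactly as shown in the diff, and uses a made-up payload:

```python
from pytorch_lightning.overrides.torch_distributed import _object_to_tensor, _tensor_to_object

payload = {"monitor": "val_loss", "best_score": 0.123}   # illustrative object
byte_tensor, size = _object_to_tensor(payload)           # object -> ByteTensor + length
restored = _tensor_to_object(byte_tensor, size.item())   # ByteTensor + length -> object
assert restored == payload
```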

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 0 additions & 1 deletion
@@ -169,5 +169,4 @@ def pre_optimizer_step(
         pl_module.trainer.call_hook("on_after_backward")

         optimizer.step(**kwargs)
-
         return False

pytorch_lightning/plugins/training_type/dp.py

Lines changed: 2 additions & 2 deletions
@@ -71,8 +71,8 @@ def barrier(self, *args, **kwargs):
     def broadcast(self, obj: object, src: int = 0) -> object:
         return obj

-    def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
-        return should_stop
+    def reduce_boolean_decision(self, decision: bool) -> bool:
+        return decision

     def training_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)

pytorch_lightning/plugins/training_type/horovod.py

Lines changed: 8 additions & 3 deletions
@@ -21,7 +21,7 @@
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
-from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
+from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp

 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
@@ -159,8 +159,13 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[
         hvd.join()
         return hvd.allreduce(tensor, op=reduce_op)

-    def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None):
-        if group is not None:
+    def all_gather(
+        self,
+        result: Union[torch.Tensor],
+        group: Optional[Any] = group.WORLD,
+        sync_grads: bool = False
+    ) -> torch.Tensor:
+        if group is not None and group != group.WORLD:
             raise ValueError(
                 "Horovod does not support allgather using a subcommunicator at this time. "
                 "Unset `group`."

pytorch_lightning/plugins/training_type/parallel.py

Lines changed: 12 additions & 18 deletions
@@ -11,11 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import io
 import os
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import List, Optional
+from typing import Any, List, Optional

 import torch
 from torch.nn.parallel import DistributedDataParallel
@@ -36,9 +35,10 @@ def __init__(
     ):
         super().__init__()
         self.parallel_devices = parallel_devices
+        self.cluster_environment = cluster_environment
+        self.global_rank = 0
         self.world_size = 1
         self.local_rank = 0
-        self.cluster_environment = cluster_environment

     @property
     @abstractmethod
@@ -70,11 +70,15 @@ def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=len(self.parallel_devices), rank=self.global_rank)
         return distributed_sampler_kwargs

-    def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
-        should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device)
-        should_stop = self.reduce(should_stop, reduce_op=ReduceOp.SUM)
-        should_stop = bool(should_stop == self.world_size)
-        return should_stop
+    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+        """Perform an all_gather on all processes"""
+        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+
+    def reduce_boolean_decision(self, decision: bool) -> bool:
+        decision = torch.tensor(int(decision), device=self.lightning_module.device)
+        decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
+        decision = bool(decision == self.world_size)
+        return decision

     @property
     def torch_distributed_backend(self):
@@ -112,13 +116,3 @@ def block_backward_sync(self):
             yield None
         else:
             yield None
-
-    def broadcast(self, obj: object, src: int) -> object:
-        buffer = io.BytesIO()
-        torch.save(obj, buffer)
-        data = bytearray(buffer.getbuffer())
-        data_tensor = torch.tensor(data).to(self.root_device, dtype=torch.float)
-        data = all_gather_ddp_if_available(data_tensor)
-        buffer = io.BytesIO(data.cpu().byte().numpy())
-        obj = torch.load(buffer)
-        return obj
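The unanimity rule in `reduce_boolean_decision` can be illustrated without a process group; the following standalone simulation (not the plugin code itself) mirrors the SUM-reduce followed by the comparison against `world_size`:

```python
import torch


def simulated_reduce_boolean_decision(per_rank_decisions):
    # Each rank contributes 0 or 1; the SUM-reduced total equals world_size
    # only if every rank voted True.
    world_size = len(per_rank_decisions)
    summed = torch.tensor([int(d) for d in per_rank_decisions]).sum()
    return bool(summed == world_size)


assert simulated_reduce_boolean_decision([True, True, True]) is True
assert simulated_reduce_boolean_decision([True, False, True]) is False
```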

pytorch_lightning/plugins/training_type/single_device.py

Lines changed: 8 additions & 1 deletion
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Union
+from typing import Any, Optional, Union

 import torch

@@ -23,6 +23,9 @@ class SingleDevicePlugin(TrainingTypePlugin):
     def __init__(self, device: torch.device):
         super().__init__()
         self.device: torch.device = device
+        self.global_rank = 0
+        self.local_rank = 0
+        self.world_size = 1

     @property
     def on_tpu(self) -> bool:
@@ -47,6 +50,10 @@ def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) ->
         """
         return tensor

+    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+        """Perform an all_gather on all processes"""
+        return tensor
+
     @property
     def root_device(self) -> torch.device:
         return self.device

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 5 additions & 6 deletions
@@ -203,12 +203,11 @@ def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
             model.trainer.save_checkpoint(path)
             return path

-    def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
-        should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device)
-        stop = xm.mesh_reduce('stop_signal', should_stop, sum)
-        rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
-        should_stop = int(stop.item()) == self.world_size
-        return should_stop
+    def reduce_decision(self, decision: bool) -> bool:
+        decision = torch.tensor(int(decision), device=self.device)
+        decision = self.reduce(decision, "sum")
+        decision = bool(decision == self.world_size)
+        return decision

     def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
         if not isinstance(output, torch.Tensor):
