
Commit 6bb24c2

tchaton authored and lexierule committed
[bug] Update broadcast + reduce decision [ModelCheckpoint] (#6410)
* resolve bug
* update
* update changelog
* update PR
* Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
  Co-authored-by: Carlos Mocholí <[email protected]>
* add todo
* resolve issues
* resolve flake8
* update
* add coverage for reduce
* wip
* restore back to brodbact
* remove test.py
* resolve flake8
* update
* check world size
* resolve test
* update
* use pytorch version when defined
* update on comments
* update on comments
* flake8
* resolve bugs
* Update CHANGELOG.md
  Co-authored-by: Carlos Mocholí <[email protected]>
* update
* update
* update
* update
* remove test
* update
* resolve flake8
* update
* update
* update
* proxy
* update
* update
* resolve typo
* prune
* update parallel
* update

Co-authored-by: Carlos Mocholí <[email protected]>

(cherry picked from commit 0544efd)
1 parent 4b762a9 commit 6bb24c2

File tree

21 files changed: +345 -153 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -128,6 +128,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))
 
 
+- Fixed broacast to use PyTorch `broadcast_object_list` and add `reduce_decision` ([#6410](https://github.com/PyTorchLightning/pytorch-lightning/pull/6410))
+
+
 - Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380))
 
 

pytorch_lightning/accelerators/accelerator.py

Lines changed: 1 addition & 2 deletions
@@ -21,7 +21,6 @@
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
 from pytorch_lightning.utilities.apply_func import move_data_to_device
-from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
 from pytorch_lightning.utilities.enums import AMPType, LightningEnum
 
 
@@ -396,7 +395,7 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
         Return:
             A tensor of shape (world_size, batch, ...)
         """
-        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+        return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads)
 
     def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
         """Wraps the dataloader if necessary

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 1 addition & 1 deletion
@@ -190,4 +190,4 @@ def _run_early_stopping_check(self, trainer, pl_module):
             trainer.should_stop = True
 
         # stop every ddp process if any world process decides to stop
-        trainer.should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop)
+        trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop)
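The renamed `reduce_boolean_decision` is a generic helper that EarlyStopping and ModelCheckpoint now share. A minimal sketch of how such a per-rank vote can be combined with an all-reduce so every rank reaches the same verdict (assumed unanimity semantics, matching the ModelCheckpoint comment below; not the actual plugin implementation, and tensor placement would depend on the backend):

import torch
import torch.distributed as dist


def reduce_boolean_decision_sketch(decision: bool) -> bool:
    """Return True only if every rank voted True; no-op without a process group."""
    if not (dist.is_available() and dist.is_initialized()):
        return decision
    vote = torch.tensor(int(decision))  # NCCL would need this on the local GPU
    dist.all_reduce(vote, op=dist.ReduceOp.SUM)  # sum of 0/1 votes across ranks
    return int(vote.item()) == dist.get_world_size()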

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 9 additions & 13 deletions
@@ -336,7 +336,7 @@ def _save_model(self, filepath: str, trainer, pl_module):
         else:
             raise ValueError(".save_function() not set")
 
-    def check_monitor_top_k(self, current) -> bool:
+    def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool:
         if current is None:
             return False
 
@@ -356,7 +356,12 @@ def check_monitor_top_k(self, current) -> bool:
             current = torch.tensor(current)
 
         monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
-        return monitor_op(current, self.best_k_models[self.kth_best_model_path]).item()
+        should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
+
+        # If using multiple devices, make sure all processes are unanimous on the decision.
+        should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save)
+
+        return should_update_best_and_save
 
     @classmethod
     def _format_checkpoint_name(
@@ -554,15 +559,7 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
         epoch = metrics.get("epoch")
         step = metrics.get("step")
 
-        # when `val_loss` is being logged and no ModelCheckpoint is being provided
-        # `val_loss` will be selected for monitor and need to be reduced to
-        # prevent processes divergence
-        # TODO: Move this logic to logger_connector. This also needs to be fixed for any
-        # other monitor logged value which aren't produced from a Metric.
-        if self.monitor == "val_loss":
-            current = trainer.training_type_plugin.reduce(current, reduce_op="mean")
-
-        if self.check_monitor_top_k(current):
+        if self.check_monitor_top_k(trainer, current):
             self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
         elif self.verbose:
             rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")
@@ -627,5 +624,4 @@ def file_exists(self, filepath: Union[str, Path], trainer) -> bool:
         the internal state to diverge between ranks.
         """
         exists = self._fs.exists(filepath)
-        exists = trainer.training_type_plugin.broadcast(exists)
-        return exists
+        return trainer.training_type_plugin.broadcast(exists)
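Two things change here: the top-k decision is reduced across processes before it is acted on, and `file_exists` returns the broadcast result directly. The comparison itself is untouched; a self-contained sketch of that part (standalone names, outside of Lightning):

import torch


def should_update_top_k(current: torch.Tensor, kth_best: torch.Tensor, mode: str = "min") -> bool:
    """Same monitor_op lookup as check_monitor_top_k: lower wins for 'min', higher for 'max'."""
    monitor_op = {"min": torch.lt, "max": torch.gt}[mode]
    return bool(monitor_op(current, kth_best))


# e.g. a lower val_loss than the current k-th best model triggers a save
assert should_update_top_k(torch.tensor(0.21), torch.tensor(0.25), mode="min")

In a multi-device run this per-rank boolean is then passed through `reduce_boolean_decision`, as shown in the diff, so all ranks save or skip together.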

pytorch_lightning/distributed/dist.py

Lines changed: 12 additions & 39 deletions
@@ -11,18 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import io
 from typing import Any
 
-import torch
-from torch import distributed as torch_distrib
-
-from pytorch_lightning.utilities import _GROUP_AVAILABLE
-
-WORLD = None
-if _GROUP_AVAILABLE:
-    from torch.distributed import group
-    WORLD = group.WORLD
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
+from pytorch_lightning.utilities.distributed import group as _group
 
 
 class LightningDistributed:
@@ -31,32 +23,13 @@ def __init__(self, rank=None, device=None):
         self.rank = rank
         self.device = device
 
-    def broadcast(self, obj: Any, group=WORLD):
-        if self.rank == 0:
-            self._emit(obj, group)
-        else:
-            obj = self._receive(group)
-        return obj
-
-    def _broadcast(self, tensor, src=0, group=WORLD):
-        if group is None:
-            return torch_distrib.broadcast(tensor, src=src)
-        return torch_distrib.broadcast(tensor, src=0, group=group)
-
-    def _emit(self, obj: Any, group=WORLD):
-        buffer = io.BytesIO()
-        torch.save(obj, buffer)
-        data = bytearray(buffer.getbuffer())
-        length_tensor = torch.tensor([len(data)]).long().to(self.device)
-        self._broadcast(length_tensor, src=0, group=group)
-        data_tensor = torch.ByteTensor(data).to(self.device)
-        self._broadcast(data_tensor, src=0, group=group)
-
-    def _receive(self, group=WORLD):
-        length_tensor = torch.tensor([0]).long().to(self.device)
-        self._broadcast(length_tensor, src=0, group=group)
-        data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device)
-        self._broadcast(data_tensor, src=0, group=group)
-        buffer = io.BytesIO(data_tensor.cpu().numpy())
-        obj = torch.load(buffer)
-        return obj
+    def broadcast(self, obj: Any, group=_group.WORLD):
+        # always wrap into a list so list can be brodcasted.
+        obj = [obj]
+
+        if self.rank != 0:
+            obj = [None] * len(obj)
+
+        broadcast_object_list(obj, 0, group=group or _group.WORLD)
+
+        return obj[0]
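The rewritten `LightningDistributed.broadcast` leans on `broadcast_object_list`, which mutates a list in place: rank 0 provides the payload, every other rank provides a placeholder, and each rank reads the result back out of position 0. A standalone sketch of the same pattern against the native PyTorch collective (requires an initialized process group and PyTorch >= 1.7; the helper name is hypothetical):

from typing import Any

import torch.distributed as dist


def broadcast_from_rank_zero(obj: Any) -> Any:
    """Broadcast any picklable object from rank 0 to every other rank."""
    payload = [obj if dist.get_rank() == 0 else None]  # single-slot container
    dist.broadcast_object_list(payload, src=0)         # fills the slot on non-zero ranks
    return payload[0]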
pytorch_lightning/overrides/torch_distributed.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+import logging
+import pickle
+
+import torch
+
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7
+
+log = logging.getLogger(__name__)
+
+if torch.distributed.is_available():
+    from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember
+
+# The code underneath is taken from PyTorch ``torch/distributed/distributed_c10d.py``
+# and enable broadcasting for PyTorch 1.6 and lower.
+
+
+# https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160
+def _rank_not_in_group(group):
+    """
+    Helper that checks if the current process's rank is not in a given group.
+    """
+    if group is None:
+        return False
+    return group == GroupMember.NON_GROUP_MEMBER
+
+
+# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164
+def _object_to_tensor(obj):
+    buffer = pickle.dumps(obj)
+    byte_storage = torch.ByteStorage.from_buffer(buffer)  # type: ignore[attr-defined]
+    byte_tensor = torch.ByteTensor(byte_storage)
+    local_size = torch.LongTensor([byte_tensor.numel()])
+    return byte_tensor, local_size
+
+
+# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py
+def _tensor_to_object(tensor, tensor_size):
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    out = pickle.loads(buf)
+    return out
+
+
+# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327
+def _broadcast_object_list(object_list, src=0, group=None):
+    if _rank_not_in_group(group):
+        return
+
+    my_rank = get_rank()
+    # Serialize object_list elements to tensors on src rank.
+    if my_rank == src:
+        tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list])
+        object_sizes_tensor = torch.cat(size_list)
+    else:
+        object_sizes_tensor = torch.LongTensor(len(object_list))
+
+    group_backend = get_backend(group)
+    is_nccl_backend = group_backend == Backend.NCCL
+    current_device = torch.device("cpu")
+    if is_nccl_backend:
+        # See note about using torch.cuda.current_device() here in docstring.
+        # We cannot simply use my_rank since rank == device is not necessarily
+        # true.
+        current_device = torch.device('cuda', torch.cuda.current_device())
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
+
+    # Broadcast object sizes
+    broadcast(object_sizes_tensor, src=src, group=group)
+
+    # Concatenate and broadcast serialized object tensors
+    if my_rank == src:
+        object_tensor = torch.cat(tensor_list)
+    else:
+        object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item())
+
+    if is_nccl_backend:
+        object_tensor = object_tensor.to(current_device)
+
+    broadcast(object_tensor, src=src, group=group)
+
+    # Deserialize objects using their stored sizes.
+    offset = 0
+    if my_rank != src:
+        for i, obj_size in enumerate(object_sizes_tensor):
+            obj_view = object_tensor[offset:offset + obj_size]
+            obj_view = obj_view.type(torch.ByteTensor)  # type: ignore[call-overload]
+            offset += obj_size
+            object_list[i] = _tensor_to_object(obj_view, obj_size)
+
+
+if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available():
+    from torch.distributed.distributed_c10d import broadcast_object_list
+else:
+    broadcast_object_list = _broadcast_object_list
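This new override module backports `broadcast_object_list` for PyTorch < 1.7 by pickling each object into a byte tensor, broadcasting the sizes and the concatenated bytes, and unpickling on the receiving ranks; on 1.7+ the native implementation is simply re-exported. The serialization round trip at the heart of it can be checked without any process group (a sketch mirroring `_object_to_tensor` / `_tensor_to_object` above; the payload dict is just an example):

import pickle

import torch


def object_to_tensor(obj):
    # pickle the object into a flat uint8 tensor and remember its length
    data = pickle.dumps(obj)
    byte_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(data))
    return byte_tensor, byte_tensor.numel()


def tensor_to_object(tensor, size):
    # reverse step: raw bytes -> pickle -> original object
    return pickle.loads(tensor.numpy().tobytes()[:size])


example = {"exists": True, "path": "checkpoints/last.ckpt"}
tensor, size = object_to_tensor(example)
assert tensor_to_object(tensor, size) == example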

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 24 additions & 17 deletions
@@ -11,10 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, List, Tuple
+from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING
 
 import torch
-from torch.optim import Optimizer
 
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
@@ -23,37 +22,41 @@
 if _APEX_AVAILABLE:
     from apex import amp
 
+if TYPE_CHECKING:
+    from torch.optim import Optimizer
+
 
 class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
     """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""
 
-    def __init__(self, amp_level: str):
+    def __init__(self, amp_level: str = "O2") -> None:
         self.backend = AMPType.APEX
         self.amp_level = amp_level
 
-    def master_params(self, optimizer: torch.optim.Optimizer):
+    def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]:
         return amp.master_params(optimizer)
 
-    def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
+    def connect(self, model: torch.nn.Module, optimizers: Sequence['Optimizer'],
+                lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence['Optimizer'], Sequence[Any]]:
         """Connects the precision plugin to the training process,
         configures apex and reinits the schedulers
         """
         if model.device.type != "cuda":
             return model, optimizers, lr_schedulers
-        model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level)
+        model, optimizers = self.configure_apex(amp, model, list(optimizers), self.amp_level)
         self.reinit_scheduler_properties(optimizers, lr_schedulers)
         return model, optimizers, lr_schedulers
 
     def backward(
         self,
         model: LightningModule,
         closure_loss: torch.Tensor,
-        optimizer: torch.optim.Optimizer,
+        optimizer: 'Optimizer',
         opt_idx: int,
         should_accumulate: bool,
-        *args,
-        **kwargs,
-    ):
+        *args: Any,
+        **kwargs: Any,
+    ) -> torch.Tensor:
         """performs the actual backpropagation
 
         Args:
@@ -94,11 +97,11 @@ def backward(
 
     def configure_apex(
         self,
-        amp: object,
+        amp: Type,
         model: LightningModule,
-        optimizers: List[Optimizer],
+        optimizers: List['Optimizer'],
         amp_level: str,
-    ) -> Tuple[LightningModule, List[Optimizer]]:
+    ) -> Tuple[LightningModule, List['Optimizer']]:
         r"""
         Override to init AMP your own way.
         Must return a model and list of optimizers.
@@ -127,7 +130,7 @@ def configure_apex(self, amp, model, optimizers, amp_level):
         return model, optimizers
 
     @staticmethod
-    def reinit_scheduler_properties(optimizers: list, schedulers: list):
+    def reinit_scheduler_properties(optimizers: Sequence['Optimizer'], schedulers: Sequence[Any]) -> None:
        """Reinitializes schedulers with correct properties"""
         # Reinitialize optimizer.step properties added by schedulers
         for scheduler in schedulers:
@@ -149,7 +152,12 @@ def reinit_scheduler_properties(optimizers: list, schedulers: list):
             break
 
     def pre_optimizer_step(
-        self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
+        self,
+        pl_module: LightningModule,
+        optimizer: 'Optimizer',
+        optimizer_idx: int,
+        lambda_closure: Callable,
+        **kwargs: Any,
     ) -> bool:
         """
         always called before the optimizer step.
@@ -160,6 +168,5 @@ def pre_optimizer_step(
         if not pl_module.automatic_optimization:
             pl_module.trainer.call_hook("on_after_backward")
 
-        optimizer.step()
-
+        optimizer.step(**kwargs)
         return False
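Beyond forwarding `**kwargs` to `optimizer.step()`, the typing changes move the `Optimizer` import behind `TYPE_CHECKING`, so `torch.optim` is only imported by static type checkers while the quoted annotation stays a plain string at runtime. The same pattern in isolation (a generic sketch, not plugin code; the function name is made up for illustration):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated by mypy/IDEs only, never executed at runtime
    from torch.optim import Optimizer


def zero_all_gradients(optimizer: "Optimizer") -> None:
    """The quoted annotation is resolved lazily, so no runtime import is needed."""
    optimizer.zero_grad()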

pytorch_lightning/plugins/training_type/dp.py

Lines changed: 2 additions & 2 deletions
@@ -71,8 +71,8 @@ def barrier(self, *args, **kwargs):
     def broadcast(self, obj: object, src: int = 0) -> object:
         return obj
 
-    def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
-        return should_stop
+    def reduce_boolean_decision(self, decision: bool) -> bool:
+        return decision
 
     def training_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)

pytorch_lightning/plugins/training_type/horovod.py

Lines changed: 8 additions & 3 deletions
@@ -21,7 +21,7 @@
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
-from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
+from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp
 
 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
@@ -147,8 +147,13 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
             hvd.join()
         return hvd.allreduce(output, op=reduce_op)
 
-    def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None):
-        if group is not None:
+    def all_gather(
+        self,
+        result: Union[torch.Tensor],
+        group: Optional[Any] = group.WORLD,
+        sync_grads: bool = False
+    ) -> torch.Tensor:
+        if group is not None and group != group.WORLD:
             raise ValueError(
                 "Horovod does not support allgather using a subcommunicator at this time. "
                 "Unset `group`."