
Commit 9028fe1

[bugfix] TPU + all_gather + SingleTPU shouldn't call xm.all_gather (#6296)
* resolve an issue with TPU
* update
* add changelog
1 parent ef03a03 commit 9028fe1

File tree

2 files changed: +6 -1 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
 
 
+- Fixed `SingleTPU` calling `all_gather` ([#6296](https://github.com/PyTorchLightning/pytorch-lightning/pull/6296))
+
 ## [1.2.2] - 2021-03-02
 
 ### Added

pytorch_lightning/accelerators/tpu.py

Lines changed: 4 additions & 1 deletion
@@ -40,4 +40,7 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
         Return:
             A tensor of shape (world_size, batch, ...)
         """
-        return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        # todo: Add support for backward with all_gather
+        if torch.distributed.is_initialized():
+            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        return tensor
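
For context, the sketch below shows roughly what the patched `TPUAccelerator.all_gather` looks like after this commit. It mirrors the diff above, but it is not a verbatim copy of the file: the imports, the class scaffolding, the docstring summary line, and the `sync_grads: bool = False` default plus return annotation (the hunk header above is truncated) are assumptions. It also assumes `torch_xla` is installed.

```python
# Hypothetical, simplified reconstruction of the all_gather method from
# pytorch_lightning/accelerators/tpu.py after this commit.
from typing import Any, Optional, Union

import torch
import torch_xla.core.xla_model as xm


class TPUAccelerator:

    def all_gather(
        self,
        tensor: Union[torch.Tensor],
        group: Optional[Any] = None,
        sync_grads: bool = False,  # assumed default; the hunk header is truncated
    ) -> torch.Tensor:
        """Gather a tensor from several distributed processes.

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        # todo: Add support for backward with all_gather
        # On a single TPU core no process group is initialized, so skip
        # xm.all_gather and return the input tensor unchanged.
        if torch.distributed.is_initialized():
            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
        return tensor
```

The added guard makes `all_gather` effectively a no-op on a `SingleTPU` setup, where no process group is initialized, instead of unconditionally calling `xm.all_gather` as before.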
