
Commit 484dce1

[bugfix] TPU + all_gather + SingleTPU shouldn't call xm.all_gather (Lightning-AI#6296)
* resolve an issue with TPU
* update
* add changelog
1 parent 4a8422c commit 484dce1

File tree

2 files changed: +7 -1 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -95,6 +95,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
 
 
+- Fixed `SingleTPU` calling `all_gather` ([#6296](https://github.com/PyTorchLightning/pytorch-lightning/pull/6296))
+
+
 - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297)
 
 
pytorch_lightning/accelerators/tpu.py

Lines changed: 4 additions & 1 deletion
@@ -44,4 +44,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         Return:
             A tensor of shape (world_size, batch, ...)
         """
-        return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        # todo: Add support for backward with all_gather
+        if torch.distributed.is_initialized():
+            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        return tensor
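
For context, here is a minimal sketch of how the patched `all_gather` method reads inside pytorch_lightning/accelerators/tpu.py after this commit. Only the method body comes from the diff above; the imports, the `TPUAccelerator` class scaffolding, and the docstring wording are assumptions added for illustration.

from typing import Any, Optional

import torch
import torch_xla.core.xla_model as xm  # assumption: the real file guards this import behind a TPU-availability check

from pytorch_lightning.accelerators.accelerator import Accelerator  # assumed location of the base Accelerator class


class TPUAccelerator(Accelerator):

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Gather a tensor of shape (batch, ...) from all processes.

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        # todo: Add support for backward with all_gather
        if torch.distributed.is_initialized():
            # A process group exists (multi-core TPU training), so the XLA collective is valid to call.
            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
        # SingleTPU: no process group was initialized, so return the tensor unchanged
        # instead of calling xm.all_gather (the bug this commit fixes).
        return tensor

With a SingleTPU setup there is no initialized process group, so the method now skips the collective and hands the input tensor back unchanged.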
