
Commit 3e5584e

borisdayma, carmocca, tchaton, and awaelchli authored and committed
fix(wandb): prevent WandbLogger from dropping values (Lightning-AI#5931)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: chaton <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 1cebe78 commit 3e5584e

5 files changed: +33 -53 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))


+- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))
+
+
 ## [1.2.1] - 2021-02-23

 ### Fixed

pytorch_lightning/loggers/wandb.py

Lines changed: 17 additions & 22 deletions
@@ -26,6 +26,8 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import WarningCache

+warning_cache = WarningCache()
+
 _WANDB_AVAILABLE = _module_available("wandb")

 try:
@@ -56,7 +58,6 @@ class WandbLogger(LightningLoggerBase):
         project: The name of the project to which this run will belong.
         log_model: Save checkpoints in wandb dir to upload on W&B servers.
         prefix: A string to put at the beginning of metric keys.
-        sync_step: Sync Trainer step with wandb step.
         experiment: WandB experiment object. Automatically set when creating a run.
         \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
             :func:`wandb.init` can be passed as keyword arguments in this logger.
@@ -92,7 +93,7 @@ def __init__(
         log_model: Optional[bool] = False,
         experiment=None,
         prefix: Optional[str] = '',
-        sync_step: Optional[bool] = True,
+        sync_step: Optional[bool] = None,
         **kwargs
     ):
         if wandb is None:
@@ -108,6 +109,12 @@ def __init__(
                 'Hint: Set `offline=False` to log your model.'
             )

+        if sync_step is not None:
+            warning_cache.warn(
+                "`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
+                " Metrics are now logged separately and automatically synchronized.", DeprecationWarning
+            )
+
         super().__init__()
         self._name = name
         self._save_dir = save_dir
@@ -117,12 +124,8 @@ def __init__(
         self._project = project
         self._log_model = log_model
         self._prefix = prefix
-        self._sync_step = sync_step
         self._experiment = experiment
         self._kwargs = kwargs
-        # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
-        self._step_offset = 0
-        self.warning_cache = WarningCache()

     def __getstate__(self):
         state = self.__dict__.copy()
@@ -159,12 +162,15 @@ def experiment(self) -> Run:
                 **self._kwargs
             ) if wandb.run is None else wandb.run

-            # offset logging step when resuming a run
-            self._step_offset = self._experiment.step
-
             # save checkpoints in wandb dir to upload on W&B servers
             if self._save_dir is None:
                 self._save_dir = self._experiment.dir
+
+            # define default x-axis (for latest wandb versions)
+            if getattr(self._experiment, "define_metric", None):
+                self._experiment.define_metric("trainer/global_step")
+                self._experiment.define_metric("*", step_metric='trainer/global_step', step_sync=True)
+
         return self._experiment

     def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
@@ -182,15 +188,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
        assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

        metrics = self._add_prefix(metrics)
-        if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
-            self.warning_cache.warn(
-                'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`'
-                ' or try logging with `commit=False` when calling manually `wandb.log`.'
-            )
-        if self._sync_step:
-            self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
-        elif step is not None:
-            self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
+        if step is not None:
+            self.experiment.log({**metrics, 'trainer/global_step': step})
         else:
             self.experiment.log(metrics)

@@ -210,10 +209,6 @@ def version(self) -> Optional[str]:

     @rank_zero_only
     def finalize(self, status: str) -> None:
-        # offset future training logged on same W&B run
-        if self._experiment is not None:
-            self._step_offset = self._experiment.step
-
         # upload all checkpoints from saving dir
         if self._log_model:
             wandb.save(os.path.join(self.save_dir, "*.ckpt"))
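
The change in `log_metrics` is the heart of the fix: instead of passing `step=` to `wandb.log` (W&B silently drops rows whose step is lower than the run's current step), the Trainer step is now logged as a regular `trainer/global_step` value, and `define_metric` registers it as the default x-axis. A minimal sketch of that pattern against the plain wandb API, assuming a recent wandb client; the project name and metric values are illustrative, not part of this commit:

import wandb

run = wandb.init(project="demo")  # illustrative project name

# Recent wandb clients expose `define_metric`; guard for older versions,
# as the patched `experiment` property does.
if getattr(run, "define_metric", None):
    run.define_metric("trainer/global_step")
    run.define_metric("*", step_metric="trainer/global_step", step_sync=True)

# No `step=` argument: the step travels as an ordinary value, so logging an
# equal or earlier step later (resumed runs, k-fold, manual wandb.log calls)
# no longer causes values to be dropped.
run.log({"train/loss": 0.25, "trainer/global_step": 100})
run.log({"val/acc": 0.90, "trainer/global_step": 100})

run.finish()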

tests/deprecated_api/test_remove_1-5.py

Lines changed: 9 additions & 0 deletions
@@ -13,13 +13,22 @@
 # limitations under the License.
 """Test deprecated functionality which will be removed in v1.5.0"""

+from unittest import mock
+
 import pytest

 from pytorch_lightning import Trainer, Callback
+from pytorch_lightning.loggers import WandbLogger
 from tests.helpers import BoringModel
 from tests.helpers.utils import no_warning_call


+@mock.patch('pytorch_lightning.loggers.wandb.wandb')
+def test_v1_5_0_wandb_unused_sync_step(tmpdir):
+    with pytest.deprecated_call(match=r"v1.2.1 and will be removed in v1.5"):
+        WandbLogger(sync_step=True)
+
+
 def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
     class OldSignature(Callback):
         def on_save_checkpoint(self, trainer, pl_module):  # noqa
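
For users, the deprecation is non-breaking: passing `sync_step` still constructs the logger, it just triggers the warning exercised by the test above and is otherwise ignored. A hedged migration sketch, assuming wandb is installed and using an illustrative project name:

from pytorch_lightning.loggers import WandbLogger

# Preferred after this change: omit sync_step, steps are synchronized automatically.
logger = WandbLogger(project="demo")

# Still accepted until v1.5, but emits a DeprecationWarning and has no effect.
legacy_logger = WandbLogger(project="demo", sync_step=False)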

tests/loggers/test_all.py

Lines changed: 1 addition & 1 deletion
@@ -404,4 +404,4 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
         wandb.run = None
         wandb.init().step = 0
         logger.log_metrics({"test": 1.0}, step=0)
-        logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0)
+        logger.experiment.log.assert_called_once_with({'tmp-test': 1.0, 'trainer/global_step': 0})

tests/loggers/test_wandb.py

Lines changed: 3 additions & 30 deletions
@@ -41,22 +41,7 @@ def test_wandb_logger_init(wandb, recwarn):
     logger = WandbLogger()
     logger.log_metrics({'acc': 1.0})
     wandb.init.assert_called_once()
-    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)
-
-    # test sync_step functionality
-    wandb.init().log.reset_mock()
-    wandb.init.reset_mock()
-    wandb.run = None
-    wandb.init().step = 0
-    logger = WandbLogger(sync_step=False)
-    logger.log_metrics({'acc': 1.0})
     wandb.init().log.assert_called_once_with({'acc': 1.0})
-    wandb.init().log.reset_mock()
-    logger.log_metrics({'acc': 1.0}, step=3)
-    wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer_step': 3})
-
-    # mock wandb step
-    wandb.init().step = 0

     # test wandb.init not called if there is a W&B run
     wandb.init().log.reset_mock()
@@ -65,13 +50,12 @@ def test_wandb_logger_init(wandb, recwarn):
     logger = WandbLogger()
     logger.log_metrics({'acc': 1.0}, step=3)
     wandb.init.assert_called_once()
-    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)
+    wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer/global_step': 3})

     # continue training on same W&B run and offset step
-    wandb.init().step = 3
     logger.finalize('success')
-    logger.log_metrics({'acc': 1.0}, step=3)
-    wandb.init().log.assert_called_with({'acc': 1.0}, step=6)
+    logger.log_metrics({'acc': 1.0}, step=6)
+    wandb.init().log.assert_called_with({'acc': 1.0, 'trainer/global_step': 6})

     # log hyper parameters
     logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
@@ -88,17 +72,6 @@ def test_wandb_logger_init(wandb, recwarn):
     logger.watch('model', 'log', 10)
     wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)

-    # verify warning for logging at a previous step
-    assert 'Trying to log at a previous step' not in get_warnings(recwarn)
-    # current step from wandb should be 6 (last logged step)
-    logger.experiment.step = 6
-    # logging at step 2 should raise a warning (step_offset is still 3)
-    logger.log_metrics({'acc': 1.0}, step=2)
-    assert 'Trying to log at a previous step' in get_warnings(recwarn)
-    # logging again at step 2 should not display again the same warning
-    logger.log_metrics({'acc': 1.0}, step=2)
-    assert 'Trying to log at a previous step' not in get_warnings(recwarn)
-
     assert logger.name == wandb.init().project_name()
     assert logger.version == wandb.init().id

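
The updated assertion can also be reproduced outside the Lightning test suite with a mocked `wandb` module, so no real W&B account is needed; a self-contained sketch, assuming pytorch_lightning with this patch applied:

from unittest import mock

from pytorch_lightning.loggers import WandbLogger

with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
    wandb.run = None  # force the logger to create its own run via wandb.init()
    logger = WandbLogger()
    logger.log_metrics({'acc': 1.0}, step=3)
    # The step is sent as a plain 'trainer/global_step' entry, not via `step=`.
    wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer/global_step': 3})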
