Commit 40d5a9d

borisdayma, carmocca, tchaton, and awaelchli authored
fix(wandb): prevent WandbLogger from dropping values (#5931)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: chaton <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent ee5032a commit 40d5a9d
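
In short, `WandbLogger` no longer forces the W&B step to track the Trainer step (the `sync_step` argument and the step-offset bookkeeping are deprecated); instead each metrics dictionary is logged together with a `trainer/global_step` key, and `wandb.define_metric` registers that key as the default x-axis where the client supports it. A minimal sketch of the resulting logger behaviour (project name and metric values are illustrative, not part of the commit):

from pytorch_lightning.loggers import WandbLogger

# Illustrative values; `sync_step` is deprecated and no longer needed.
logger = WandbLogger(project="my-project", offline=True)

# After this commit the logger calls wandb.log({"acc": 0.9, "trainer/global_step": 3})
# instead of wandb.log({"acc": 0.9}, step=3), so a later call with a lower step
# is recorded rather than dropped.
logger.log_metrics({"acc": 0.9}, step=3)
logger.log_metrics({"val_loss": 0.1}, step=1)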

File tree

5 files changed: +33 −53 lines changed

CHANGELOG.md
pytorch_lightning/loggers/wandb.py
tests/deprecated_api/test_remove_1-5.py
tests/loggers/test_all.py
tests/loggers/test_wandb.py


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -63,6 +63,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))
 
 
+- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))
+
+
 ## [1.2.1] - 2021-02-23
 
 ### Fixed

pytorch_lightning/loggers/wandb.py

Lines changed: 17 additions & 22 deletions
@@ -26,6 +26,8 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import WarningCache
 
+warning_cache = WarningCache()
+
 _WANDB_AVAILABLE = _module_available("wandb")
 
 try:
@@ -56,7 +58,6 @@ class WandbLogger(LightningLoggerBase):
         project: The name of the project to which this run will belong.
         log_model: Save checkpoints in wandb dir to upload on W&B servers.
         prefix: A string to put at the beginning of metric keys.
-        sync_step: Sync Trainer step with wandb step.
         experiment: WandB experiment object. Automatically set when creating a run.
         \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
             :func:`wandb.init` can be passed as keyword arguments in this logger.
@@ -98,7 +99,7 @@ def __init__(
         log_model: Optional[bool] = False,
         experiment=None,
         prefix: Optional[str] = '',
-        sync_step: Optional[bool] = True,
+        sync_step: Optional[bool] = None,
         **kwargs
     ):
         if wandb is None:
@@ -114,6 +115,12 @@ def __init__(
                 'Hint: Set `offline=False` to log your model.'
             )
 
+        if sync_step is not None:
+            warning_cache.warn(
+                "`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
+                " Metrics are now logged separately and automatically synchronized.", DeprecationWarning
+            )
+
         super().__init__()
         self._name = name
         self._save_dir = save_dir
@@ -123,12 +130,8 @@ def __init__(
         self._project = project
         self._log_model = log_model
         self._prefix = prefix
-        self._sync_step = sync_step
         self._experiment = experiment
         self._kwargs = kwargs
-        # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
-        self._step_offset = 0
-        self.warning_cache = WarningCache()
 
     def __getstate__(self):
         state = self.__dict__.copy()
@@ -165,12 +168,15 @@ def experiment(self) -> Run:
                 **self._kwargs
             ) if wandb.run is None else wandb.run
 
-            # offset logging step when resuming a run
-            self._step_offset = self._experiment.step
-
             # save checkpoints in wandb dir to upload on W&B servers
             if self._save_dir is None:
                 self._save_dir = self._experiment.dir
+
+            # define default x-axis (for latest wandb versions)
+            if getattr(self._experiment, "define_metric", None):
+                self._experiment.define_metric("trainer/global_step")
+                self._experiment.define_metric("*", step_metric='trainer/global_step', step_sync=True)
+
         return self._experiment
 
     def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
@@ -188,15 +194,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
 
         metrics = self._add_prefix(metrics)
-        if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
-            self.warning_cache.warn(
-                'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`'
-                ' or try logging with `commit=False` when calling manually `wandb.log`.'
-            )
-        if self._sync_step:
-            self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
-        elif step is not None:
-            self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
+        if step is not None:
+            self.experiment.log({**metrics, 'trainer/global_step': step})
         else:
             self.experiment.log(metrics)
 
@@ -216,10 +215,6 @@ def version(self) -> Optional[str]:
 
     @rank_zero_only
     def finalize(self, status: str) -> None:
-        # offset future training logged on same W&B run
-        if self._experiment is not None:
-            self._step_offset = self._experiment.step
-
         # upload all checkpoints from saving dir
         if self._log_model:
             wandb.save(os.path.join(self.save_dir, "*.ckpt"))
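
The `experiment` property above leans on W&B's `define_metric` API: declaring `trainer/global_step` as the step metric for every key (`"*"`) tells the W&B backend to plot all metrics against the Trainer's global step, so the order of `log` calls no longer matters. A rough standalone sketch of the same setup outside Lightning (run mode, names, and values are illustrative assumptions):

import wandb

run = wandb.init(project="demo", mode="offline")

# Mirrors the guarded setup in WandbLogger.experiment: older wandb clients
# do not have define_metric, hence the getattr check in the diff above.
if getattr(run, "define_metric", None):
    run.define_metric("trainer/global_step")
    run.define_metric("*", step_metric="trainer/global_step", step_sync=True)

# Metrics carry their own step key instead of wandb's internal `step=` argument,
# so a later call with a smaller step is kept instead of being dropped.
run.log({"train_loss": 0.42, "trainer/global_step": 10})
run.log({"val_acc": 0.91, "trainer/global_step": 5})

run.finish()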

tests/deprecated_api/test_remove_1-5.py

Lines changed: 9 additions & 0 deletions
@@ -13,13 +13,22 @@
 # limitations under the License.
 """Test deprecated functionality which will be removed in v1.5.0"""
 
+from unittest import mock
+
 import pytest
 
 from pytorch_lightning import Trainer, Callback
+from pytorch_lightning.loggers import WandbLogger
 from tests.helpers import BoringModel
 from tests.helpers.utils import no_warning_call
 
 
+@mock.patch('pytorch_lightning.loggers.wandb.wandb')
+def test_v1_5_0_wandb_unused_sync_step(tmpdir):
+    with pytest.deprecated_call(match=r"v1.2.1 and will be removed in v1.5"):
+        WandbLogger(sync_step=True)
+
+
 def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
     class OldSignature(Callback):
         def on_save_checkpoint(self, trainer, pl_module):  # noqa

tests/loggers/test_all.py

Lines changed: 1 addition & 1 deletion
@@ -404,4 +404,4 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
     wandb.run = None
     wandb.init().step = 0
     logger.log_metrics({"test": 1.0}, step=0)
-    logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0)
+    logger.experiment.log.assert_called_once_with({'tmp-test': 1.0, 'trainer/global_step': 0})

tests/loggers/test_wandb.py

Lines changed: 3 additions & 30 deletions
@@ -41,22 +41,7 @@ def test_wandb_logger_init(wandb, recwarn):
     logger = WandbLogger()
     logger.log_metrics({'acc': 1.0})
     wandb.init.assert_called_once()
-    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)
-
-    # test sync_step functionality
-    wandb.init().log.reset_mock()
-    wandb.init.reset_mock()
-    wandb.run = None
-    wandb.init().step = 0
-    logger = WandbLogger(sync_step=False)
-    logger.log_metrics({'acc': 1.0})
     wandb.init().log.assert_called_once_with({'acc': 1.0})
-    wandb.init().log.reset_mock()
-    logger.log_metrics({'acc': 1.0}, step=3)
-    wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer_step': 3})
-
-    # mock wandb step
-    wandb.init().step = 0
 
     # test wandb.init not called if there is a W&B run
     wandb.init().log.reset_mock()
@@ -65,13 +50,12 @@ def test_wandb_logger_init(wandb, recwarn):
     logger = WandbLogger()
     logger.log_metrics({'acc': 1.0}, step=3)
     wandb.init.assert_called_once()
-    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)
+    wandb.init().log.assert_called_once_with({'acc': 1.0, 'trainer/global_step': 3})
 
     # continue training on same W&B run and offset step
-    wandb.init().step = 3
     logger.finalize('success')
-    logger.log_metrics({'acc': 1.0}, step=3)
-    wandb.init().log.assert_called_with({'acc': 1.0}, step=6)
+    logger.log_metrics({'acc': 1.0}, step=6)
+    wandb.init().log.assert_called_with({'acc': 1.0, 'trainer/global_step': 6})
 
     # log hyper parameters
     logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
@@ -88,17 +72,6 @@ def test_wandb_logger_init(wandb, recwarn):
     logger.watch('model', 'log', 10)
     wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)
 
-    # verify warning for logging at a previous step
-    assert 'Trying to log at a previous step' not in get_warnings(recwarn)
-    # current step from wandb should be 6 (last logged step)
-    logger.experiment.step = 6
-    # logging at step 2 should raise a warning (step_offset is still 3)
-    logger.log_metrics({'acc': 1.0}, step=2)
-    assert 'Trying to log at a previous step' in get_warnings(recwarn)
-    # logging again at step 2 should not display again the same warning
-    logger.log_metrics({'acc': 1.0}, step=2)
-    assert 'Trying to log at a previous step' not in get_warnings(recwarn)
-
     assert logger.name == wandb.init().project_name()
     assert logger.version == wandb.init().id
10477
