 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import WarningCache
 
+warning_cache = WarningCache()
+
 _WANDB_AVAILABLE = _module_available("wandb")
 
 try:
@@ -56,7 +58,6 @@ class WandbLogger(LightningLoggerBase):
         project: The name of the project to which this run will belong.
         log_model: Save checkpoints in wandb dir to upload on W&B servers.
         prefix: A string to put at the beginning of metric keys.
-        sync_step: Sync Trainer step with wandb step.
         experiment: WandB experiment object. Automatically set when creating a run.
         \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
             :func:`wandb.init` can be passed as keyword arguments in this logger.
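
# --- Usage sketch (not part of the diff) --------------------------------------
# Illustrates the arguments documented above; all values are placeholders and
# extra keyword arguments such as `tags` are simply forwarded to `wandb.init`.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(
    name="baseline-run",      # display name of the run
    project="my-project",     # W&B project this run will belong to
    log_model=True,           # upload checkpoints from the save dir to W&B
    prefix="train",           # prepended to every metric key
    tags=["demo"],            # passed through **kwargs to wandb.init
)
trainer = Trainer(logger=wandb_logger)
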
@@ -98,7 +99,7 @@ def __init__(
         log_model: Optional[bool] = False,
         experiment=None,
         prefix: Optional[str] = '',
-        sync_step: Optional[bool] = True,
+        sync_step: Optional[bool] = None,
         **kwargs
     ):
         if wandb is None:
@@ -114,6 +115,12 @@ def __init__(
                 'Hint: Set `offline=False` to log your model.'
             )
 
+        if sync_step is not None:
+            warning_cache.warn(
+                "`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
+                " Metrics are now logged separately and automatically synchronized.", DeprecationWarning
+            )
+
         super().__init__()
         self._name = name
         self._save_dir = save_dir
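
# --- Behaviour sketch (not part of the diff) -----------------------------------
# The guard above keeps `sync_step` accepted for backward compatibility, but it
# only emits a DeprecationWarning and no longer changes how steps are logged.
# Assumes wandb is installed; the project name is a placeholder.
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(project="my-project")                   # recommended: omit sync_step
legacy = WandbLogger(project="my-project", sync_step=True)   # still accepted, but warns once
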
@@ -123,12 +130,8 @@ def __init__(
         self._project = project
         self._log_model = log_model
         self._prefix = prefix
-        self._sync_step = sync_step
         self._experiment = experiment
         self._kwargs = kwargs
-        # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
-        self._step_offset = 0
-        self.warning_cache = WarningCache()
 
     def __getstate__(self):
         state = self.__dict__.copy()
@@ -165,12 +168,15 @@ def experiment(self) -> Run:
                 **self._kwargs
             ) if wandb.run is None else wandb.run
 
-            # offset logging step when resuming a run
-            self._step_offset = self._experiment.step
-
             # save checkpoints in wandb dir to upload on W&B servers
             if self._save_dir is None:
                 self._save_dir = self._experiment.dir
+
+            # define default x-axis (for latest wandb versions)
+            if getattr(self._experiment, "define_metric", None):
+                self._experiment.define_metric("trainer/global_step")
+                self._experiment.define_metric("*", step_metric='trainer/global_step', step_sync=True)
+
         return self._experiment
 
     def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
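
# --- Standalone sketch of the define_metric mechanism (not part of the diff) ---
# Requires a recent wandb client that provides `define_metric`; older clients
# simply skip the branch above via the `getattr` check. The project name and
# values are placeholders. Every metric is plotted against the custom
# "trainer/global_step" axis instead of wandb's internal step counter.
import wandb

run = wandb.init(project="my-project")
run.define_metric("trainer/global_step")
run.define_metric("*", step_metric="trainer/global_step", step_sync=True)

# Later calls only need to include the step as an ordinary key:
run.log({"train/loss": 0.42, "trainer/global_step": 100})
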
@@ -188,15 +194,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
 
         metrics = self._add_prefix(metrics)
-        if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
-            self.warning_cache.warn(
-                'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`'
-                ' or try logging with `commit=False` when calling manually `wandb.log`.'
-            )
-        if self._sync_step:
-            self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
-        elif step is not None:
-            self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
+        if step is not None:
+            self.experiment.log({**metrics, 'trainer/global_step': step})
         else:
             self.experiment.log(metrics)
 
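
# --- Sketch of the new logging contract (not part of the diff) -----------------
# The Trainer step is now written as a plain "trainer/global_step" metric rather
# than being forced into wandb's internal step, so logging at an earlier step or
# mixing in manual `wandb.log` calls no longer triggers the old
# "Trying to log at a previous step" warning. Assumes wandb is installed;
# `offline=True` avoids needing a W&B login, and the project name is a placeholder.
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(project="my-project", offline=True)
logger.log_metrics({"val/acc": 0.91}, step=50)
logger.log_metrics({"val/acc": 0.93}, step=40)         # earlier step: accepted silently
logger.experiment.log({"extra/figure_metric": 1.0})    # manual log without a step key
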
@@ -216,10 +215,6 @@ def version(self) -> Optional[str]:
 
     @rank_zero_only
     def finalize(self, status: str) -> None:
-        # offset future training logged on same W&B run
-        if self._experiment is not None:
-            self._step_offset = self._experiment.step
-
         # upload all checkpoints from saving dir
         if self._log_model:
             wandb.save(os.path.join(self.save_dir, "*.ckpt"))
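
# --- End-to-end sketch (not part of the diff) ----------------------------------
# The `_step_offset` bookkeeping removed above existed so that several Trainer
# runs (k-fold, resuming) could share a single W&B run. With the step logged as
# an ordinary "trainer/global_step" metric, the same pattern needs no offset at
# all. `MyModel`, `train_dl` and `val_dl` are placeholder objects.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project="my-project", log_model=True)  # checkpoints uploaded in finalize()
for fold in range(5):
    trainer = Trainer(max_epochs=3, logger=wandb_logger)
    trainer.fit(MyModel(), train_dl, val_dl)  # every fold logs to the same W&B run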