@@ -26,6 +26,8 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import WarningCache
 
+warning_cache = WarningCache()
+
 _WANDB_AVAILABLE = _module_available("wandb")
 
 try:
@@ -56,7 +58,6 @@ class WandbLogger(LightningLoggerBase):
         project: The name of the project to which this run will belong.
         log_model: Save checkpoints in wandb dir to upload on W&B servers.
         prefix: A string to put at the beginning of metric keys.
-        sync_step: Sync Trainer step with wandb step.
         experiment: WandB experiment object. Automatically set when creating a run.
         \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
             :func:`wandb.init` can be passed as keyword arguments in this logger.
@@ -92,7 +93,7 @@ def __init__(
         log_model: Optional[bool] = False,
         experiment=None,
         prefix: Optional[str] = '',
-        sync_step: Optional[bool] = True,
+        sync_step: Optional[bool] = None,
         **kwargs
     ):
         if wandb is None:
@@ -108,6 +109,12 @@ def __init__(
                 'Hint: Set `offline=False` to log your model.'
             )
 
+        if sync_step is not None:
+            warning_cache.warn(
+                "`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
+                " Metrics are now logged separately and automatically synchronized.", DeprecationWarning
+            )
+
         super().__init__()
         self._name = name
         self._save_dir = save_dir
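Note on the hunk above: the deprecation notice goes through the module-level `warning_cache` rather than `warnings.warn` directly, so constructing several `WandbLogger` instances surfaces the message only once. A rough sketch of the idea (illustration only, not the library source; PyTorch Lightning's actual `WarningCache` lives in `pytorch_lightning.utilities.warnings`):

import warnings

class WarningCacheSketch:
    # Remembers every message it has already emitted and skips repeats.
    def __init__(self):
        self._seen = set()

    def warn(self, message, *args, **kwargs):
        if message not in self._seen:
            self._seen.add(message)
            warnings.warn(message, *args, **kwargs)

cache = WarningCacheSketch()
cache.warn("`sync_step` is deprecated", DeprecationWarning)
cache.warn("`sync_step` is deprecated", DeprecationWarning)  # silently skipped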
@@ -117,12 +124,8 @@ def __init__(
         self._project = project
         self._log_model = log_model
         self._prefix = prefix
-        self._sync_step = sync_step
         self._experiment = experiment
         self._kwargs = kwargs
-        # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
-        self._step_offset = 0
-        self.warning_cache = WarningCache()
 
     def __getstate__(self):
         state = self.__dict__.copy()
@@ -159,12 +162,15 @@ def experiment(self) -> Run:
                 **self._kwargs
             ) if wandb.run is None else wandb.run
 
-            # offset logging step when resuming a run
-            self._step_offset = self._experiment.step
-
         # save checkpoints in wandb dir to upload on W&B servers
         if self._save_dir is None:
             self._save_dir = self._experiment.dir
+
+        # define default x-axis (for latest wandb versions)
+        if getattr(self._experiment, "define_metric", None):
+            self._experiment.define_metric("trainer/global_step")
+            self._experiment.define_metric("*", step_metric='trainer/global_step', step_sync=True)
+
         return self._experiment
 
     def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
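The `define_metric` calls added above make `trainer/global_step` the default x-axis for every metric on the run, which is what lets `log_metrics` below drop wandb's internal step entirely. A minimal sketch of the same wandb API outside Lightning (assuming a wandb version that ships `define_metric`, roughly 0.10.30+; the project name is a placeholder):

import wandb

run = wandb.init(project="define-metric-demo")

# Plot everything against "trainer/global_step" instead of wandb's internal step.
run.define_metric("trainer/global_step")
run.define_metric("*", step_metric="trainer/global_step", step_sync=True)

# Logging out of order is fine: charts are keyed by the step metric.
run.log({"val/loss": 0.25, "trainer/global_step": 100})
run.log({"train/loss": 0.30, "trainer/global_step": 90})
run.finish()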
@@ -182,15 +188,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
 
         metrics = self._add_prefix(metrics)
-        if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
-            self.warning_cache.warn(
-                'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`'
-                ' or try logging with `commit=False` when calling manually `wandb.log`.'
-            )
-        if self._sync_step:
-            self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
-        elif step is not None:
-            self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
+        if step is not None:
+            self.experiment.log({**metrics, 'trainer/global_step': step})
         else:
             self.experiment.log(metrics)
 
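Because the step now travels inside the metrics payload as `trainer/global_step`, resuming a run or logging from several Trainers against one W&B run (the k-fold case the removed `_step_offset` used to handle) needs no offset bookkeeping. A hedged usage sketch with the public logger API from this file (the project name is a placeholder):

from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(project="kfold-demo")

# Two training phases that both start counting at step 0 can share one W&B run;
# each point is keyed by trainer/global_step rather than a monotonic wandb step.
for step in range(10):
    logger.log_metrics({"fold0/loss": 1.0 / (step + 1)}, step=step)
for step in range(10):
    logger.log_metrics({"fold1/loss": 1.0 / (step + 1)}, step=step)

logger.finalize("success")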
@@ -210,10 +209,6 @@ def version(self) -> Optional[str]:
 
     @rank_zero_only
     def finalize(self, status: str) -> None:
-        # offset future training logged on same W&B run
-        if self._experiment is not None:
-            self._step_offset = self._experiment.step
-
         # upload all checkpoints from saving dir
         if self._log_model:
             wandb.save(os.path.join(self.save_dir, "*.ckpt"))