Skip to content

Commit 423ecf9

Browse files
asnorkintchatontadejsvananyahjha93Alexander Snorkin
authored
Feature/5275 clean progress bar print (#5470)
* Trainer.test should return only test metrics (#5214) * resolve bug * merge tests * Fix metric state reset (#5273) * Fix metric state reset * Fix test * Improve formatting Co-authored-by: Ananya Harsh Jha <[email protected]> * print() method added to ProgressBar * printing alongside progress bar added to LightningModule.print() * LightningModule.print() method documentation updated * ProgressBarBase.print() stub added * stub * add progress bar tests * fix isort * Progress Callback fixes * test_metric.py duplicate DummyList removed * PEP and isort fixes * CHANGELOG updated * test_progress_bar_print win linesep fix * test_progress_bar.py remove whitespaces * Update CHANGELOG.md Co-authored-by: chaton <[email protected]> Co-authored-by: Tadej Svetina <[email protected]> Co-authored-by: Ananya Harsh Jha <[email protected]> Co-authored-by: Alexander Snorkin <[email protected]> Co-authored-by: rohitgr7 <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 57215b7 commit 423ecf9

File tree

4 files changed

+102
-4
lines changed

4 files changed

+102
-4
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8989
- Added DeepSpeed integration ([#5954](https://github.com/PyTorchLightning/pytorch-lightning/pull/5954),
9090
[#6042](https://github.com/PyTorchLightning/pytorch-lightning/pull/6042))
9191

92+
- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
93+
94+
9295
### Changed
9396

9497
- Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))

pytorch_lightning/callbacks/progress.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
2020
"""
2121
import importlib
22+
import io
23+
import os
2224
import sys
2325

2426
# check if ipywidgets is installed before importing tqdm.auto
@@ -187,6 +189,12 @@ def enable(self):
187189
"""
188190
raise NotImplementedError
189191

192+
def print(self, *args, **kwargs):
193+
"""
194+
You should provide a way to print without breaking the progress bar.
195+
"""
196+
print(*args, **kwargs)
197+
190198
def on_init_end(self, trainer):
191199
self._trainer = trainer
192200

@@ -451,6 +459,22 @@ def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, da
451459
def on_predict_end(self, trainer, pl_module):
452460
self.predict_progress_bar.close()
453461

462+
def print(
463+
self, *args, sep: str = ' ', end: str = os.linesep, file: Optional[io.TextIOBase] = None, nolock: bool = False
464+
):
465+
active_progress_bar = None
466+
467+
if not self.main_progress_bar.disable:
468+
active_progress_bar = self.main_progress_bar
469+
elif not self.val_progress_bar.disable:
470+
active_progress_bar = self.val_progress_bar
471+
elif not self.test_progress_bar.disable:
472+
active_progress_bar = self.test_progress_bar
473+
474+
if active_progress_bar is not None:
475+
s = sep.join(map(str, args))
476+
active_progress_bar.write(s, end=end, file=file, nolock=nolock)
477+
454478
def _should_update(self, current, total):
455479
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
456480

pytorch_lightning/core/lightning.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,8 @@ def print(self, *args, **kwargs) -> None:
189189
Prints only from process 0. Use this in any distributed mode to log only once.
190190
191191
Args:
192-
*args: The thing to print. Will be passed to Python's built-in print function.
193-
**kwargs: Will be passed to Python's built-in print function.
192+
*args: The thing to print. The same as for Python's built-in print function.
193+
**kwargs: The same as for Python's built-in print function.
194194
195195
Example::
196196
@@ -199,7 +199,11 @@ def forward(self, x):
199199
200200
"""
201201
if self.trainer.is_global_zero:
202-
print(*args, **kwargs)
202+
progress_bar = self.trainer.progress_bar_callback
203+
if progress_bar is not None and progress_bar.is_enabled:
204+
progress_bar.print(*args, **kwargs)
205+
else:
206+
print(*args, **kwargs)
203207

204208
def log(
205209
self,

tests/callbacks/test_progress_bar.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
import sys
1516
from unittest import mock
16-
from unittest.mock import call, Mock
17+
from unittest.mock import ANY, call, Mock
1718

1819
import pytest
1920
import torch
@@ -381,3 +382,69 @@ def training_step(self, batch, batch_idx):
381382
def test_tqdm_format_num(input_num, expected):
382383
""" Check that the specialized tqdm.format_num appends 0 to floats and strings """
383384
assert tqdm.format_num(input_num) == expected
385+
386+
387+
class PrintModel(BoringModel):
388+
389+
def training_step(self, *args, **kwargs):
390+
self.print("training_step", end="")
391+
return super().training_step(*args, **kwargs)
392+
393+
def validation_step(self, *args, **kwargs):
394+
self.print("validation_step", file=sys.stderr)
395+
return super().validation_step(*args, **kwargs)
396+
397+
def test_step(self, *args, **kwargs):
398+
self.print("test_step")
399+
return super().test_step(*args, **kwargs)
400+
401+
402+
@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write")
403+
def test_progress_bar_print(tqdm_write, tmpdir):
404+
""" Test that printing in the LightningModule redirects arguments to the progress bar. """
405+
model = PrintModel()
406+
bar = ProgressBar()
407+
trainer = Trainer(
408+
default_root_dir=tmpdir,
409+
num_sanity_val_steps=0,
410+
limit_train_batches=1,
411+
limit_val_batches=1,
412+
limit_test_batches=1,
413+
max_steps=1,
414+
callbacks=[bar],
415+
)
416+
trainer.fit(model)
417+
trainer.test(model)
418+
assert tqdm_write.call_count == 3
419+
assert tqdm_write.call_args_list == [
420+
call("training_step", end="", file=None, nolock=False),
421+
call("validation_step", end=os.linesep, file=sys.stderr, nolock=False),
422+
call("test_step", end=os.linesep, file=None, nolock=False),
423+
]
424+
425+
426+
@mock.patch('builtins.print')
427+
@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write")
428+
def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
429+
""" Test that printing in LightningModule goes through built-in print functin when progress bar is disabled. """
430+
model = PrintModel()
431+
bar = ProgressBar()
432+
trainer = Trainer(
433+
default_root_dir=tmpdir,
434+
num_sanity_val_steps=0,
435+
limit_train_batches=1,
436+
limit_val_batches=1,
437+
limit_test_batches=1,
438+
max_steps=1,
439+
callbacks=[bar],
440+
)
441+
bar.disable()
442+
trainer.fit(model)
443+
trainer.test(model)
444+
445+
mock_print.assert_has_calls([
446+
call("training_step", end=""),
447+
call("validation_step", file=ANY),
448+
call("test_step"),
449+
])
450+
tqdm_write.assert_not_called()

0 commit comments

Comments
 (0)