Skip to content

Commit 0456b45

Browse files
authored
mini refactor for _running_stage access (#5724)
* running stage
* circular import
* running stage cleanup
* fix unused import
* fix running stage access
* add return type
* Revert "add return type" — this reverts commit 65b0fe2.
* try fix typing
1 parent 423ecf9 commit 0456b45

File tree

5 files changed

+24
-27
lines changed

5 files changed

+24
-27
lines changed

pytorch_lightning/core/lightning.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from argparse import Namespace
2525
from functools import partial
2626
from pathlib import Path
27-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
27+
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
2828

2929
import torch
3030
from torch import ScriptModule, Tensor
@@ -44,6 +44,9 @@
4444
from pytorch_lightning.utilities.exceptions import MisconfigurationException
4545
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
4646

47+
if TYPE_CHECKING:
48+
from pytorch_lightning.trainer.states import RunningStage
49+
4750

4851
class LightningModule(
4952
ABC,
@@ -103,7 +106,6 @@ def __init__(self, *args, **kwargs):
103106
self._running_manual_backward = False
104107
self._current_hook_fx_name = None
105108
self._current_dataloader_idx = None
106-
self.running_stage = None
107109
self._automatic_optimization: bool = True
108110

109111
def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
@@ -169,6 +171,10 @@ def automatic_optimization(self) -> bool:
169171
"""
170172
return self._automatic_optimization
171173

174+
@property
175+
def running_stage(self) -> Optional["RunningStage"]:
176+
return self.trainer._running_stage if self.trainer else None
177+
172178
@automatic_optimization.setter
173179
def automatic_optimization(self, automatic_optimization: bool) -> None:
174180
self._automatic_optimization = automatic_optimization

pytorch_lightning/trainer/trainer.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
6060
from pytorch_lightning.utilities.cloud_io import load as pl_load
6161
from pytorch_lightning.utilities.debugging import InternalDebugger
62-
from pytorch_lightning.utilities.enums import LightningEnum
6362
from pytorch_lightning.utilities.exceptions import MisconfigurationException
6463
from pytorch_lightning.utilities.memory import recursive_detach
6564
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -450,7 +449,7 @@ def fit(
450449
# bookkeeping
451450
# we reuse fit in .test() and .predict(). When already set, it shouldn't be modified.
452451
if self._running_stage is None:
453-
self._set_running_stage(RunningStage.TRAINING, model)
452+
self._running_stage = RunningStage.TRAINING
454453

455454
# set local properties on the model
456455
self.model_connector.copy_trainer_model_properties(model)
@@ -531,7 +530,7 @@ def fit(
531530
if self._state != TrainerState.INTERRUPTED:
532531
self._state = TrainerState.FINISHED
533532

534-
self._set_running_stage(None, model)
533+
self._running_stage = None
535534

536535
return self.accelerator.results or 1
537536

@@ -564,14 +563,6 @@ def train_or_test_or_predict(self):
564563

565564
return results
566565

567-
def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule):
568-
"""
569-
This function is used to set the running_state on both
570-
the trainer and the model
571-
"""
572-
model_ref.running_stage = stage
573-
self._running_stage = stage
574-
575566
def _pre_training_routine(self):
576567
# wait for all to join if on distributed
577568
self.accelerator.barrier("setup_training")
@@ -614,7 +605,7 @@ def run_train(self):
614605
self.run_sanity_check(self.lightning_module)
615606

616607
# set stage for logging
617-
self._set_running_stage(RunningStage.TRAINING, self.lightning_module)
608+
self._running_stage = RunningStage.TRAINING
618609

619610
self.checkpoint_connector.has_trained = False
620611

@@ -678,9 +669,7 @@ def run_train(self):
678669
def run_evaluation(self, max_batches=None, on_epoch=False):
679670

680671
# used to know if we are logging for val, test + reset cached results
681-
self._set_running_stage(
682-
RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.lightning_module
683-
)
672+
self._running_stage = RunningStage.TESTING if self.testing else RunningStage.EVALUATING
684673
self.logger_connector.reset()
685674

686675
# bookkeeping
@@ -907,7 +896,7 @@ def test(
907896
# --------------------
908897
self.verbose_test = verbose
909898

910-
self._set_running_stage(RunningStage.TESTING, model or self.lightning_module)
899+
self._running_stage = RunningStage.TESTING
911900

912901
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
913902
if test_dataloaders and datamodule:
@@ -924,7 +913,7 @@ def test(
924913
results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
925914

926915
self.teardown('test')
927-
self._set_running_stage(None, model or self.lightning_module)
916+
self._running_stage = None
928917
return results
929918

930919
def __test_using_best_weights(self, ckpt_path, test_dataloaders):
@@ -1016,7 +1005,7 @@ def predict(
10161005

10171006
model = model or self.lightning_module
10181007

1019-
self._set_running_stage(RunningStage.PREDICTING, model)
1008+
self._running_stage = RunningStage.PREDICTING
10201009

10211010
if dataloaders and datamodule:
10221011
raise MisconfigurationException(
@@ -1033,7 +1022,7 @@ def predict(
10331022

10341023
self.model = model
10351024
results = self.fit(model)
1036-
self._set_running_stage(None, model)
1025+
self._running_stage = None
10371026

10381027
return results
10391028

pytorch_lightning/trainer/training_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ def run_training_epoch(self):
517517
self.trainer.run_evaluation()
518518

519519
# reset stage to train
520-
self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
520+
self.trainer._running_stage = RunningStage.TRAINING
521521

522522
# -----------------------------------------
523523
# SAVE LOGGERS (ie: Tensorboard, etc...)
@@ -564,7 +564,7 @@ def run_training_epoch(self):
564564
self.trainer.run_evaluation(on_epoch=True)
565565

566566
# reset stage to train
567-
self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
567+
self.trainer._running_stage = RunningStage.TRAINING
568568

569569
should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
570570
should_train_only = self.trainer.disable_validation or should_skip_eval

tests/models/test_restore.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def on_train_start(self):
453453
# haven't trained with the new loaded model
454454
dp_model = new_trainer.model
455455
dp_model.eval()
456-
dp_model.module.module.running_stage = RunningStage.EVALUATING
456+
new_trainer._running_stage = RunningStage.EVALUATING
457457

458458
dataloader = self.train_dataloader()
459459
tpipes.run_prediction(self.trainer.lightning_module, dataloader)

tests/overrides/test_data_parallel.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import MagicMock
1+
from unittest.mock import MagicMock, Mock
22

33
import pytest
44
import torch
@@ -103,7 +103,8 @@ def training_step(self, batch, batch_idx):
103103
return {"loss": loss}
104104

105105
model = TestModel()
106-
model.running_stage = RunningStage.TRAINING
106+
model.trainer = Mock()
107+
model.trainer._running_stage = RunningStage.TRAINING
107108
batch = torch.rand(2, 32).cuda()
108109
batch_idx = 0
109110

@@ -146,7 +147,8 @@ def training_step(self, batch, batch_idx):
146147

147148
model = TestModel()
148149
model.to(device)
149-
model.running_stage = RunningStage.TRAINING
150+
model.trainer = Mock()
151+
model.trainer._running_stage = RunningStage.TRAINING
150152
batch = torch.rand(2, 32).to(device)
151153
batch_idx = 0
152154

0 commit comments

Comments
 (0)