Skip to content

Commit cbdf2a8

Browse files
committed
fix running stage access
1 parent bab7691 commit cbdf2a8

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

tests/models/test_restore.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def on_train_start(self):
453453
# haven't trained with the new loaded model
454454
dp_model = new_trainer.model
455455
dp_model.eval()
456-
dp_model.module.module.running_stage = RunningStage.EVALUATING
456+
new_trainer._running_stage = RunningStage.EVALUATING
457457

458458
dataloader = self.train_dataloader()
459459
tpipes.run_prediction(self.trainer.lightning_module, dataloader)

tests/overrides/test_data_parallel.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import MagicMock
1+
from unittest.mock import MagicMock, Mock
22

33
import pytest
44
import torch
@@ -103,7 +103,8 @@ def training_step(self, batch, batch_idx):
103103
return {"loss": loss}
104104

105105
model = TestModel()
106-
model.running_stage = RunningStage.TRAINING
106+
model.trainer = Mock()
107+
model.trainer._running_stage = RunningStage.TRAINING
107108
batch = torch.rand(2, 32).cuda()
108109
batch_idx = 0
109110

@@ -146,7 +147,8 @@ def training_step(self, batch, batch_idx):
146147

147148
model = TestModel()
148149
model.to(device)
149-
model.running_stage = RunningStage.TRAINING
150+
model.trainer = Mock()
151+
model.trainer._running_stage = RunningStage.TRAINING
150152
batch = torch.rand(2, 32).to(device)
151153
batch_idx = 0
152154

0 commit comments

Comments
 (0)