Skip to content

Commit 953c873

Browse files
kaushikb11lexierule
authored andcommitted
Fix: Train loop config validation was run during trainer.predict (#6541)
1 parent c954fbc commit 953c873

File tree

3 files changed

+44
-1
lines changed

3 files changed

+44
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
139139
- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))
140140

141141

142+
- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))
143+
144+
142145
## [1.2.3] - 2021-03-09
143146

144147
### Fixed

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ def verify_loop_configurations(self, model: LightningModule):
3030
model: The model to check the configuration.
3131
3232
"""
33-
if not self.trainer.testing:
33+
if self.trainer.predicting:
34+
self.__verify_predict_loop_configuration(model)
35+
elif not self.trainer.testing:
3436
self.__verify_train_loop_configuration(model)
3537
self.__verify_eval_loop_configuration(model, 'validation')
3638
else:
@@ -98,3 +100,9 @@ def __verify_eval_loop_configuration(self, model, eval_loop_name):
98100
rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop')
99101
if has_step and not has_loader:
100102
rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop')
103+
104+
def __verify_predict_loop_configuration(self, model):
105+
106+
has_predict_dataloader = is_overridden('predict_dataloader', model)
107+
if not has_predict_dataloader:
108+
raise MisconfigurationException('Dataloader not found for `Trainer.predict`')

tests/trainer/test_trainer.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1850,3 +1850,35 @@ def compare_optimizers():
18501850
trainer.max_epochs = 2 # simulate multiple fit calls
18511851
trainer.fit(model)
18521852
compare_optimizers()
1853+
1854+
1855+
@pytest.mark.parametrize("use_datamodule", [False, True])
1856+
def test_trainer_predict_verify_config(tmpdir, use_datamodule):
1857+
1858+
class TestModel(LightningModule):
1859+
1860+
def __init__(self):
1861+
super().__init__()
1862+
self.layer = torch.nn.Linear(32, 2)
1863+
1864+
def forward(self, x):
1865+
return self.layer(x)
1866+
1867+
dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))]
1868+
1869+
model = TestModel()
1870+
trainer = Trainer(default_root_dir=tmpdir)
1871+
1872+
if use_datamodule:
1873+
datamodule = TestLightningDataModule(dataloaders)
1874+
results = trainer.predict(model, datamodule=datamodule)
1875+
else:
1876+
results = trainer.predict(model, dataloaders=dataloaders)
1877+
1878+
assert len(results) == 2
1879+
assert results[0][0].shape == torch.Size([1, 2])
1880+
1881+
model.predict_dataloader = None
1882+
1883+
with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"):
1884+
trainer.predict(model)

0 commit comments

Comments
 (0)