Skip to content

Commit 37f22c9

Browse files
kaushikb11carmocca
andauthored
Add trainer.predict config validation (#6543)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 634d831 commit 37f22c9

File tree

3 files changed

+59
-4
lines changed

3 files changed

+59
-4
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4040
- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
4141

4242

43-
- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))
43+
- Added `Trainer.predict` config validation ([#6543](https://github.com/PyTorchLightning/pytorch-lightning/pull/6543))
44+
4445

46+
- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))
4547

4648

4749
### Changed

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def verify_loop_configurations(self, model: LightningModule) -> None:
4040
self.__verify_eval_loop_configuration(model, 'val')
4141
elif self.trainer.state == TrainerState.TESTING:
4242
self.__verify_eval_loop_configuration(model, 'test')
43-
# TODO: add predict
43+
elif self.trainer.state == TrainerState.PREDICTING:
44+
self.__verify_predict_loop_configuration(model)
4445

4546
def __verify_train_loop_configuration(self, model):
4647
# -----------------------------------
@@ -99,3 +100,9 @@ def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -
99100
rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop')
100101
if has_step and not has_loader:
101102
rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {stage} loop')
103+
104+
def __verify_predict_loop_configuration(self, model: LightningModule) -> None:
105+
106+
has_predict_dataloader = is_overridden('predict_dataloader', model)
107+
if not has_predict_dataloader:
108+
raise MisconfigurationException('Dataloader not found for `Trainer.predict`')

tests/trainer/test_config_validator.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import pytest
15+
import torch
1516

16-
from pytorch_lightning import Trainer
17+
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
1718
from pytorch_lightning.utilities.exceptions import MisconfigurationException
18-
from tests.helpers import BoringModel
19+
from tests.helpers import BoringModel, RandomDataset
1920

2021

2122
def test_wrong_train_setting(tmpdir):
@@ -101,3 +102,48 @@ def test_val_loop_config(tmpdir):
101102
model = BoringModel()
102103
model.validation_step = None
103104
trainer.validate(model)
105+
106+
107+
@pytest.mark.parametrize("datamodule", [False, True])
108+
def test_trainer_predict_verify_config(tmpdir, datamodule):
109+
110+
class TestModel(LightningModule):
111+
112+
def __init__(self):
113+
super().__init__()
114+
self.layer = torch.nn.Linear(32, 2)
115+
116+
def forward(self, x):
117+
return self.layer(x)
118+
119+
class TestLightningDataModule(LightningDataModule):
120+
121+
def __init__(self, dataloaders):
122+
super().__init__()
123+
self._dataloaders = dataloaders
124+
125+
def test_dataloader(self):
126+
return self._dataloaders
127+
128+
def predict_dataloader(self):
129+
return self._dataloaders
130+
131+
dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))]
132+
133+
model = TestModel()
134+
135+
trainer = Trainer(default_root_dir=tmpdir)
136+
137+
if datamodule:
138+
datamodule = TestLightningDataModule(dataloaders)
139+
results = trainer.predict(model, datamodule=datamodule)
140+
else:
141+
results = trainer.predict(model, dataloaders=dataloaders)
142+
143+
assert len(results) == 2
144+
assert results[0][0].shape == torch.Size([1, 2])
145+
146+
model.predict_dataloader = None
147+
148+
with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"):
149+
trainer.predict(model)

0 commit comments

Comments
 (0)