Update docs for limit_predict_batches #6507

Merged (3 commits) on Mar 14, 2021

6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/data_loading.py
@@ -293,17 +293,17 @@ def _reset_eval_dataloader(
         if mode in modes and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):

             # when overfitting, the dataloader should not have sampler
-            if self.overfit_batches > 0:
+            if self.overfit_batches > 0 and mode != 'predict':
                 rank_zero_warn(
-                    'You requested to overfit but enabled test/val dataloader shuffling.'
+                    'You requested to overfit but enabled val/test dataloader shuffling.'
                     ' We are turning it off for you.'
                 )
                 dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))

             else:
                 rank_zero_warn(
                     f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
-                    ' this off for validation and test dataloaders.'
+                    ' this off for val/test/predict dataloaders.'
                 )

         if any([dl is None for dl in dataloaders]):
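
For context, a minimal runnable sketch of the behavior this hunk changes; `reset_eval_loader` is a hypothetical stand-in for the sampler check in `_reset_eval_dataloader`, not Lightning API. With `overfit_batches > 0`, shuffled val/test loaders are rewrapped with a `SequentialSampler`, while a shuffled predict loader now only gets the best-practice warning:

```python
import warnings

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset


def reset_eval_loader(loader: DataLoader, mode: str, overfit_batches: int) -> DataLoader:
    """Hypothetical stand-in for the sampler check above."""
    if isinstance(loader.sampler, RandomSampler):
        if overfit_batches > 0 and mode != 'predict':
            warnings.warn('You requested to overfit but enabled val/test dataloader shuffling.'
                          ' We are turning it off for you.')
            # replace the random sampler with a sequential one
            return DataLoader(loader.dataset, batch_size=loader.batch_size,
                              sampler=SequentialSampler(loader.dataset))
        warnings.warn(f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
                      ' this off for val/test/predict dataloaders.')
    return loader


dataset = TensorDataset(torch.arange(8).float())
shuffled = DataLoader(dataset, batch_size=2, shuffle=True)

val_loader = reset_eval_loader(shuffled, 'val', overfit_batches=1)          # sampler replaced
predict_loader = reset_eval_loader(shuffled, 'predict', overfit_batches=1)  # warning only, sampler kept
```
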
9 changes: 4 additions & 5 deletions pytorch_lightning/trainer/predict_loop.py
@@ -28,17 +28,16 @@ def __init__(self, trainer):
     def on_trainer_init(self):
         self.trainer.num_predict_batches = []

-    def get_predict_dataloaders(self, max_batches):
+    def get_predict_dataloaders(self):
         self.trainer.reset_predict_dataloader(self.trainer.lightning_module)

         dataloaders = self.trainer.predict_dataloaders
-        if max_batches is None:
-            max_batches = self.trainer.num_predict_batches
+        max_batches = self.trainer.num_predict_batches

         return dataloaders, max_batches

-    def should_skip_predict(self, dataloaders, max_batches):
-        return dataloaders is None or not sum(max_batches)
+    def should_skip_predict(self, max_batches):
+        return sum(max_batches) == 0

     def on_predict_model_eval(self, *_, **__):
         model_ref = self.trainer.lightning_module
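
The simplification rests on `trainer.num_predict_batches` always being a list of per-dataloader batch counts, which makes both the `None` check and the `dataloaders is None` guard redundant. A tiny standalone sketch of the new skip rule (an assumption for illustration, not Lightning source):

```python
def should_skip_predict(max_batches: list) -> bool:
    """Skip prediction exactly when every predict dataloader contributes zero batches."""
    return sum(max_batches) == 0


assert should_skip_predict([0, 0])      # e.g. limit_predict_batches=0 -> skip
assert not should_skip_predict([3, 0])  # at least one batch somewhere -> run
```
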
18 changes: 9 additions & 9 deletions pytorch_lightning/trainer/trainer.py
@@ -198,11 +198,13 @@ def __init__(

             gradient_clip_val: 0 means don't clip.

-            limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)
+            limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches)

-            limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)
+            limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches)

-            limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches)
+            limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches)
+
+            limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches)

             logger: Logger (or iterable collection of loggers) for experiment tracking.

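
To illustrate the documented float-vs-int semantics, a hedged usage sketch (the `limit_predict_batches` argument itself is what this PR documents):

```python
from pytorch_lightning import Trainer

# float = fraction of the predict dataset, int = absolute number of batches
trainer = Trainer(limit_predict_batches=0.25)  # predict on 25% of the batches
trainer = Trainer(limit_predict_batches=10)    # predict on exactly 10 batches
```
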
@@ -221,7 +223,7 @@ def __init__(

             profiler: To profile individual steps during training and assist in identifying bottlenecks.

-            overfit_batches: Overfit a percent of training data (float) or a set number of batches (int).
+            overfit_batches: Overfit a fraction of training data (float) or a set number of batches (int).

             plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.

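The same convention applies to `overfit_batches`, e.g. (hedged sketch):

```python
from pytorch_lightning import Trainer

trainer = Trainer(overfit_batches=0.01)  # overfit on 1% of the training data
trainer = Trainer(overfit_batches=5)     # overfit on 5 fixed batches
```
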
@@ -754,10 +756,10 @@ def run_evaluate(self):

     def run_predict(self):
         # prepare dataloaders
-        dataloaders, max_batches = self.predict_loop.get_predict_dataloaders(None)
+        dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()

         # check if we want to skip this evaluation
-        if self.predict_loop.should_skip_predict(dataloaders, max_batches):
+        if self.predict_loop.should_skip_predict(max_batches):
             return []

         # ref model
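
Taken together with the predict_loop change, a zero-batch configuration now short-circuits `run_predict` to an empty list. A hedged sketch (`my_model` and `my_predict_loader` are placeholders):

```python
from pytorch_lightning import Trainer

trainer = Trainer(limit_predict_batches=0)
# num_predict_batches sums to 0, should_skip_predict(...) returns True, so:
# trainer.predict(my_model, dataloaders=my_predict_loader)  # -> []
```
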
@@ -922,9 +924,7 @@ def test(

         # If you supply a datamodule you can't supply test_dataloaders
         if test_dataloaders and datamodule:
-            raise MisconfigurationException(
-                'You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`'
-            )
+            raise MisconfigurationException('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`')

         model_provided = model is not None
         model = model or self.lightning_module
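
The guard's behavior is unchanged; only the `raise` is collapsed onto one line. For reference, a hypothetical standalone version of the same check:

```python
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def check_test_args(test_dataloaders, datamodule):
    # Passing both an explicit test dataloader and a datamodule is rejected.
    if test_dataloaders and datamodule:
        raise MisconfigurationException('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`')
```
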