Skip to content

Commit dcd9dd8

Browse files
authored
Update docs for limit_predict_batches (#6507)
* add docs and minor updates
* docs
* fraction
1 parent b2bcad1 commit dcd9dd8

File tree

3 files changed

+16
-17
lines changed

3 files changed

+16
-17
lines changed

pytorch_lightning/trainer/data_loading.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,17 +293,17 @@ def _reset_eval_dataloader(
293293
if mode in modes and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):
294294

295295
# when overfitting, the dataloader should not have sampler
296-
if self.overfit_batches > 0:
296+
if self.overfit_batches > 0 and mode != 'predict':
297297
rank_zero_warn(
298-
'You requested to overfit but enabled test/val dataloader shuffling.'
298+
'You requested to overfit but enabled val/test dataloader shuffling.'
299299
' We are turning it off for you.'
300300
)
301301
dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))
302302

303303
else:
304304
rank_zero_warn(
305305
f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
306-
' this off for validation and test dataloaders.'
306+
' this off for val/test/predict dataloaders.'
307307
)
308308

309309
if any([dl is None for dl in dataloaders]):

pytorch_lightning/trainer/predict_loop.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,16 @@ def __init__(self, trainer):
2828
def on_trainer_init(self):
2929
self.trainer.num_predict_batches = []
3030

31-
def get_predict_dataloaders(self, max_batches):
31+
def get_predict_dataloaders(self):
3232
self.trainer.reset_predict_dataloader(self.trainer.lightning_module)
3333

3434
dataloaders = self.trainer.predict_dataloaders
35-
if max_batches is None:
36-
max_batches = self.trainer.num_predict_batches
35+
max_batches = self.trainer.num_predict_batches
3736

3837
return dataloaders, max_batches
3938

40-
def should_skip_predict(self, dataloaders, max_batches):
41-
return dataloaders is None or not sum(max_batches)
39+
def should_skip_predict(self, max_batches):
40+
return sum(max_batches) == 0
4241

4342
def on_predict_model_eval(self, *_, **__):
4443
model_ref = self.trainer.lightning_module

pytorch_lightning/trainer/trainer.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,13 @@ def __init__(
198198
199199
gradient_clip_val: 0 means don't clip.
200200
201-
limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)
201+
limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches)
202202
203-
limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)
203+
limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches)
204204
205-
limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches)
205+
limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches)
206+
207+
limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches)
206208
207209
logger: Logger (or iterable collection of loggers) for experiment tracking.
208210
@@ -221,7 +223,7 @@ def __init__(
221223
222224
profiler: To profile individual steps during training and assist in identifying bottlenecks.
223225
224-
overfit_batches: Overfit a percent of training data (float) or a set number of batches (int).
226+
overfit_batches: Overfit a fraction of training data (float) or a set number of batches (int).
225227
226228
plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
227229
@@ -754,10 +756,10 @@ def run_evaluate(self):
754756

755757
def run_predict(self):
756758
# prepare dataloaders
757-
dataloaders, max_batches = self.predict_loop.get_predict_dataloaders(None)
759+
dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()
758760

759761
# check if we want to skip this evaluation
760-
if self.predict_loop.should_skip_predict(dataloaders, max_batches):
762+
if self.predict_loop.should_skip_predict(max_batches):
761763
return []
762764

763765
# ref model
@@ -922,9 +924,7 @@ def test(
922924

923925
# If you supply a datamodule you can't supply test_dataloaders
924926
if test_dataloaders and datamodule:
925-
raise MisconfigurationException(
926-
'You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`'
927-
)
927+
raise MisconfigurationException('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`')
928928

929929
model_provided = model is not None
930930
model = model or self.lightning_module

0 commit comments

Comments (0)