@@ -198,11 +198,13 @@ def __init__(
198
198
199
199
gradient_clip_val: 0 means don't clip.
200
200
201
- limit_train_batches: How much of training dataset to check (floats = percent , int = num_batches)
201
+ limit_train_batches: How much of training dataset to check (float = fraction , int = num_batches)
202
202
203
- limit_val_batches: How much of validation dataset to check (floats = percent , int = num_batches)
203
+ limit_val_batches: How much of validation dataset to check (float = fraction , int = num_batches)
204
204
205
- limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches)
205
+ limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches)
206
+
207
+ limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches)
206
208
207
209
logger: Logger (or iterable collection of loggers) for experiment tracking.
208
210
@@ -221,7 +223,7 @@ def __init__(
221
223
222
224
profiler: To profile individual steps during training and assist in identifying bottlenecks.
223
225
224
- overfit_batches: Overfit a percent of training data (float) or a set number of batches (int).
226
+ overfit_batches: Overfit a fraction of training data (float) or a set number of batches (int).
225
227
226
228
plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
227
229
@@ -754,10 +756,10 @@ def run_evaluate(self):
754
756
755
757
def run_predict (self ):
756
758
# prepare dataloaders
757
- dataloaders , max_batches = self .predict_loop .get_predict_dataloaders (None )
759
+ dataloaders , max_batches = self .predict_loop .get_predict_dataloaders ()
758
760
759
761
# check if we want to skip this evaluation
760
- if self .predict_loop .should_skip_predict (dataloaders , max_batches ):
762
+ if self .predict_loop .should_skip_predict (max_batches ):
761
763
return []
762
764
763
765
# ref model
@@ -922,9 +924,7 @@ def test(
922
924
923
925
# If you supply a datamodule you can't supply test_dataloaders
924
926
if test_dataloaders and datamodule :
925
- raise MisconfigurationException (
926
- 'You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`'
927
- )
927
+ raise MisconfigurationException ('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`' )
928
928
929
929
model_provided = model is not None
930
930
model = model or self .lightning_module
0 commit comments