
Commit 9fc733f

awaelchli authored and lexierule committed
Fix tuner.scale_batch_size not finding batch size attribute when using datamodule (#5968)
(cherry picked from commit b2bcad1)
1 parent 73ef543 commit 9fc733f
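
In practice, the bug shows up when the `batch_size` attribute lives on a `LightningDataModule` rather than on the `LightningModule`. A minimal sketch of the now-supported call, modelled on the test added in this commit (the `BatchSizeModel` / `BatchSizeDataModule` helpers are defined in `tests/tuner/test_scale_batch_size.py` below; this is an illustration, not part of the diff):

# Sketch based on the new test below: only the datamodule defines `batch_size`.
from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner
from tests.tuner.test_scale_batch_size import BatchSizeDataModule, BatchSizeModel

model = BatchSizeModel(batch_size=None)         # no batch_size on the model
datamodule = BatchSizeDataModule(batch_size=2)  # the attribute lives here

trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0)
tuner = Tuner(trainer)

# Before this commit, calling scale_batch_size directly skipped the fit setup,
# so the trainer never saw the datamodule and the attribute lookup failed.
new_batch_size = tuner.scale_batch_size(
    model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
)
# Per the test's assertions, the result is also written back to datamodule.batch_size.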

File tree

4 files changed (+94 additions, -8 deletions):

- CHANGELOG.md
- pytorch_lightning/trainer/training_loop.py
- pytorch_lightning/tuner/tuning.py
- tests/tuner/test_scale_batch_size.py

CHANGELOG.md

Lines changed: 17 additions & 5 deletions
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
 
+- Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
 
 
 - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
@@ -21,9 +22,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
 
 
+- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948))
+
+
 - Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))
 
 
+- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277))
+
+
 - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))
 
 
@@ -49,6 +56,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Deprecated
 
+- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
+
 
 - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
 
@@ -122,15 +131,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380))
 
 
-- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
+- Fixed DeepSpeed additional memory use on rank 0 when default device not set early enough ([#6460](https://github.com/PyTorchLightning/pytorch-lightning/pull/6460))
 
 
-- Fixed DeepSpeed additional memory use on rank 0 when default device not set early enough ([#6460](https://github.com/PyTorchLightning/pytorch-lightning/pull/6460))
+- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
 
 
 - Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398))
 
 
+- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968))
+
+
 ## [1.2.3] - 2021-03-09
 
 ### Fixed
@@ -148,6 +160,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))
 
 
+- Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398))
+
+
 ## [1.2.2] - 2021-03-02
 
 ### Added
@@ -169,9 +184,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297)
 
 
-- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))
-
-
 ## [1.2.1] - 2021-02-23
 
 ### Fixed

pytorch_lightning/trainer/training_loop.py

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ def on_train_start(self):
         # provide rank to profiler
         self.trainer.profile_connector.on_train_start(self.trainer)
 
-    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
+    def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
         # clean hparams
         if hasattr(model, "hparams"):
            parsing.clean_namespace(model.hparams)
pytorch_lightning/tuner/tuning.py

Lines changed: 11 additions & 2 deletions
@@ -32,13 +32,20 @@ def on_trainer_init(self, auto_lr_find, auto_scale_batch_size):
         self.trainer.auto_lr_find = auto_lr_find
         self.trainer.auto_scale_batch_size = auto_scale_batch_size
 
-    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
+    def setup_trainer(
+        self,
+        model: LightningModule,
+        train_dataloader: Optional[DataLoader] = None,
+        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        datamodule: LightningDataModule = None,
+    ):
+        self.trainer.model_connector.copy_trainer_model_properties(model)
         # setup data, etc...
         self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)
-
         # hook
         self.trainer.data_connector.prepare_data(model)
 
+    def tune(self, model, train_dataloader, val_dataloaders, datamodule):
         # Run auto batch size scaling
         if self.trainer.auto_scale_batch_size:
             if isinstance(self.trainer.auto_scale_batch_size, bool):
@@ -101,6 +108,7 @@ def scale_batch_size(
             or datamodule.
 
         """
+        self.setup_trainer(model, **fit_kwargs)
         return scale_batch_size(
             self.trainer,
             model,
@@ -125,6 +133,7 @@ def lr_find(
         datamodule: Optional[LightningDataModule] = None,
         update_attr: bool = False,
     ):
+        self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule)
         return lr_find(
             self.trainer,
             model,
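
The same `setup_trainer` call is added to `lr_find` (last hunk above), so the learning-rate finder invoked directly through the `Tuner` also gets the dataloaders and datamodule attached before the search starts. A rough usage sketch under that assumption, reusing the model and datamodule from the sketch near the top of this page; `lr_finder.suggestion()` is the usual LR-finder result API and is assumed here, not something this diff touches:

# Sketch only: standalone LR finding with a datamodule.
trainer = Trainer(max_epochs=1)
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model, datamodule=datamodule)  # datamodule attached via setup_trainer
new_lr = lr_finder.suggestion()  # suggestion() assumed from the standard LR-finder result object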

tests/tuner/test_scale_batch_size.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+from torch.utils.data import DataLoader
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.tuner.tuning import Tuner
+from tests.helpers import BoringDataModule, BoringModel
+
+
+class BatchSizeDataModule(BoringDataModule):
+
+    def __init__(self, batch_size=None):
+        super().__init__()
+        if batch_size is not None:
+            self.batch_size = batch_size
+
+    def train_dataloader(self):
+        return DataLoader(self.random_train, batch_size=getattr(self, "batch_size", 1))
+
+
+class BatchSizeModel(BoringModel):
+
+    def __init__(self, batch_size=None):
+        super().__init__()
+        if batch_size is not None:
+            self.batch_size = batch_size
+
+
+@pytest.mark.parametrize(
+    "model,datamodule", [
+        (BatchSizeModel(2), None),
+        (BatchSizeModel(2), BatchSizeDataModule(2)),
+        (BatchSizeModel(2), BatchSizeDataModule(None)),
+        (BatchSizeModel(None), BatchSizeDataModule(2)),
+    ]
+)
+def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule):
+    """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=0,
+        max_epochs=1,
+    )
+    tuner = Tuner(trainer)
+    new_batch_size = tuner.scale_batch_size(
+        model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
+    )
+    assert new_batch_size == 16
+    if hasattr(model, "batch_size"):
+        assert model.batch_size == 16
+    if datamodule is not None and hasattr(datamodule, "batch_size"):
+        assert datamodule.batch_size == 16
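
The expected value of 16 in every parametrization presumably comes from the binsearch mode doubling twice from `init_val=4` within `max_trials=2` (4 → 8 → 16); the trailing asserts then check that the found batch size is written back onto whichever of the model or datamodule actually defines the attribute.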
