Skip to content

Commit c48fc6a

Browse files
authored
[test] lr_find with bs_scale (#6422)
* init test: test_lr_find_with_bs_scale * Update test_lr_finder.py * remove gpu req * try boring model * custom boring model * pep8 * fix typo * Update test_lr_finder.py * typo * typo
1 parent b341b53 commit c48fc6a

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/trainer/test_lr_finder.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,3 +271,26 @@ def test_lr_finder_fails_fast_on_bad_config(tmpdir):
271271
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, auto_lr_find=True)
272272
with pytest.raises(MisconfigurationException, match='should have one of these fields'):
273273
trainer.tune(BoringModel())
274+
275+
276+
def test_lr_find_with_bs_scale(tmpdir):
277+
""" Test that lr_find runs with batch_size_scaling """
278+
279+
class BoringModelTune(BoringModel):
280+
def __init__(self, learning_rate=0.1, batch_size=2):
281+
super().__init__()
282+
self.save_hyperparameters()
283+
284+
model = BoringModelTune()
285+
before_lr = model.hparams.learning_rate
286+
287+
# logger file to get meta
288+
trainer = Trainer(
289+
default_root_dir=tmpdir,
290+
max_epochs=3,
291+
)
292+
bs = trainer.tuner.scale_batch_size(model)
293+
lr = trainer.tuner.lr_find(model).suggestion()
294+
295+
assert lr != before_lr
296+
assert isinstance(bs, int)

0 commit comments

Comments
 (0)