Skip to content

Commit 634d831

Browse files
authored
Add AMP for validation, prediction and testing (#6565)
* Add Tests for val and test-steps * Add native AMP * pep8 tests * pep8 plugin * changelog
1 parent cb59039 commit 634d831

File tree

3 files changed

+48
-3
lines changed

3 files changed

+48
-3
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
116116

117117
### Fixed
118118

119+
- Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565))
120+
119121
- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
120122

121123

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,21 @@ def train_step_context(self) -> Generator[None, None, None]:
103103
"""Enable autocast context"""
104104
with torch.cuda.amp.autocast():
105105
yield
106+
107+
@contextmanager
108+
def val_step_context(self) -> Generator[None, None, None]:
109+
"""Enable autocast context"""
110+
with torch.cuda.amp.autocast():
111+
yield
112+
113+
@contextmanager
114+
def test_step_context(self) -> Generator[None, None, None]:
115+
"""Enable autocast context"""
116+
with torch.cuda.amp.autocast():
117+
yield
118+
119+
@contextmanager
120+
def predict_context(self) -> Generator[None, None, None]:
121+
"""Enable autocast context"""
122+
with torch.cuda.amp.autocast():
123+
yield

tests/models/test_amp.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,43 @@
1717
import pytest
1818
import torch
1919
from torch import optim
20+
from torch.utils.data import DataLoader
2021

2122
import tests.helpers.utils as tutils
2223
from pytorch_lightning import Trainer
2324
from pytorch_lightning.plugins.environments import SLURMEnvironment
2425
from pytorch_lightning.trainer.states import TrainerState
2526
from pytorch_lightning.utilities.exceptions import MisconfigurationException
26-
from tests.helpers import BoringModel
27+
from tests.helpers import BoringModel, RandomDataset
2728
from tests.helpers.runif import RunIf
2829

2930

3031
class AMPTestModel(BoringModel):
3132

32-
def training_step(self, batch, batch_idx):
33+
def _step(self, batch, batch_idx):
3334
assert torch.is_autocast_enabled()
3435
output = self(batch)
3536
assert output.dtype == torch.float16
3637
loss = self.loss(batch, output)
37-
return {"loss": loss}
38+
return loss
39+
40+
def training_step(self, batch, batch_idx):
41+
output = self._step(batch, batch_idx)
42+
return {"loss": output}
43+
44+
def validation_step(self, batch, batch_idx):
45+
output = self._step(batch, batch_idx)
46+
return {"x": output}
47+
48+
def test_step(self, batch, batch_idx):
49+
output = self._step(batch, batch_idx)
50+
return {"y": output}
51+
52+
def predict(self, batch, batch_idx, dataloader_idx=None):
53+
assert torch.is_autocast_enabled()
54+
output = self(batch)
55+
assert output.dtype == torch.float16
56+
return output
3857

3958

4059
@pytest.mark.skip(reason='dp + amp not supported currently') # TODO
@@ -54,6 +73,8 @@ def test_amp_single_gpu_dp(tmpdir):
5473
model = AMPTestModel()
5574
# tutils.run_model_test(trainer_options, model)
5675
trainer.fit(model)
76+
trainer.test(model)
77+
trainer.predict(model, DataLoader(RandomDataset(32, 64)))
5778

5879
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
5980

@@ -73,6 +94,8 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
7394
model = AMPTestModel()
7495
# tutils.run_model_test(trainer_options, model)
7596
trainer.fit(model)
97+
trainer.test(model)
98+
trainer.predict(model, DataLoader(RandomDataset(32, 64)))
7699
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
77100

78101

@@ -112,6 +135,8 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
112135
model = AMPTestModel()
113136
# tutils.run_model_test(trainer_options, model)
114137
trainer.fit(model)
138+
trainer.test(model)
139+
trainer.predict(model, DataLoader(RandomDataset(32, 64)))
115140
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
116141

117142

0 commit comments

Comments
 (0)