Skip to content

Commit 4b7c0fa

Browse files
authored
Fix amp autocast (#6080)
* precision fixes
* add amp test model
* fix test
* revert
* move assert to training step
* fix test
* fix test
* remove unrelated changes
* add changelog
* remove unused import
1 parent 0b27147 commit 4b7c0fa

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2121

2222
### Fixed
2323

24+
- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
25+
2426

2527
## [1.2.0] - 2021-02-18
2628

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,5 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
9191
@contextmanager
9292
def train_step_context(self) -> Generator[autocast, None, None]:
9393
"""Enable autocast context"""
94-
yield torch.cuda.amp.autocast()
94+
with torch.cuda.amp.autocast():
95+
yield

tests/models/test_amp.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,16 @@
2727
from tests.helpers import BoringModel
2828

2929

30+
class AMPTestModel(BoringModel):
31+
32+
def training_step(self, batch, batch_idx):
33+
assert torch.is_autocast_enabled()
34+
output = self(batch)
35+
assert output.dtype == torch.float16
36+
loss = self.loss(batch, output)
37+
return {"loss": loss}
38+
39+
3040
@pytest.mark.skip(reason='dp + amp not supported currently') # TODO
3141
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
3242
def test_amp_single_gpu_dp(tmpdir):
@@ -41,7 +51,7 @@ def test_amp_single_gpu_dp(tmpdir):
4151
precision=16,
4252
)
4353

44-
model = BoringModel()
54+
model = AMPTestModel()
4555
# tutils.run_model_test(trainer_options, model)
4656
trainer.fit(model)
4757

@@ -60,10 +70,9 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
6070
precision=16,
6171
)
6272

63-
model = BoringModel()
73+
model = AMPTestModel()
6474
# tutils.run_model_test(trainer_options, model)
6575
trainer.fit(model)
66-
6776
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
6877

6978

@@ -81,7 +90,7 @@ def test_amp_multi_gpu_dp(tmpdir):
8190
precision=16,
8291
)
8392

84-
model = BoringModel()
93+
model = AMPTestModel()
8594
# tutils.run_model_test(trainer_options, model)
8695
trainer.fit(model)
8796

@@ -100,10 +109,9 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
100109
precision=16,
101110
)
102111

103-
model = BoringModel()
112+
model = AMPTestModel()
104113
# tutils.run_model_test(trainer_options, model)
105114
trainer.fit(model)
106-
107115
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
108116

109117

@@ -122,7 +130,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
122130
# simulate setting slurm flags
123131
tutils.set_random_master_port()
124132

125-
model = BoringModel()
133+
model = AMPTestModel()
126134

127135
# exp file to get meta
128136
logger = tutils.get_default_logger(tmpdir)

0 commit comments

Comments (0)