
Commit 3b0e4e0

Enable ZeRO tests for CI, fix to/half function calls (#6070)
* Enable ZeRO optimization, and make sure that the lightning module hook is called when we move to half precision
* Added test; updated the `to` function
1 parent 97a81c3 commit 3b0e4e0

File tree

5 files changed: +88 -23 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))


+- Move lightning module to correct device type when using LightningDistributedWrapper ([#6070](https://github.com/PyTorchLightning/pytorch-lightning/pull/6070))
+
+
 ## [1.2.0] - 2021-02-18

 ### Added

pytorch_lightning/overrides/base.py

Lines changed: 5 additions & 1 deletion

@@ -19,12 +19,13 @@

 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.warnings import WarningCache

 warning_cache = WarningCache()


-class _LightningModuleWrapperBase(torch.nn.Module):
+class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):

     def __init__(self, pl_module: LightningModule):
         """
@@ -72,6 +73,9 @@ def forward(self, *inputs, **kwargs):

         return output

+    def on_post_move_to_device(self):
+        pass
+

 def warn_if_output_is_none(output: Any, method_name: str) -> None:
     """ Warns user about which method returned None. """

pytorch_lightning/utilities/device_dtype_mixin.py

Lines changed: 3 additions & 2 deletions

@@ -119,7 +119,7 @@ def to(self, *args, **kwargs) -> Module:
         self.__update_properties(device=out[0], dtype=out[1])
         return super().to(*args, **kwargs)

-    def cuda(self, device: Optional[int] = None) -> Module:
+    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Module:
         """Moves all model parameters and buffers to the GPU.
         This also makes associated parameters and buffers different objects. So
         it should be called before constructing optimizer if the module will
@@ -132,7 +132,8 @@ def cuda(self, device: Optional[int] = None) -> Module:
         Returns:
             Module: self
         """
-        self.__update_properties(device=torch.device('cuda', index=device))
+        property_device = device if isinstance(device, torch.device) else torch.device('cuda', index=device)
+        self.__update_properties(device=property_device)
         return super().cuda(device=device)

     def cpu(self) -> Module:
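
In practical terms, `cuda()` on any module using this mixin (LightningModule included) now accepts either an integer index or a full `torch.device`, and the tracked `device` property stays consistent in both cases. A hedged usage sketch on a CUDA machine, following the parametrised test added further below:

    import torch

    from tests.helpers.boring_model import BoringModel  # any LightningModule behaves the same way

    model = BoringModel()

    # All three forms are covered by the new test parametrisation.
    model.cuda(0)                        # integer index, as before
    model.cuda(torch.device('cuda', 0))  # torch.device is now handled explicitly
    model.cuda()                         # no argument at all

    assert model.device.type == 'cuda'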

tests/plugins/test_deepspeed_plugin.py

Lines changed: 67 additions & 17 deletions

@@ -8,11 +8,52 @@

 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
+from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
 from pytorch_lightning.utilities import _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, _NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel


+def test_deepspeed_lightning_module(tmpdir):
+    """
+    Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves types and device correctly.
+    """
+
+    model = BoringModel()
+    module = LightningDeepSpeedModule(model, precision=16)
+
+    module.half()
+    assert module.dtype == torch.half
+    assert model.dtype == torch.half
+
+    module.to(torch.double)
+    assert module.dtype == torch.double
+    assert model.dtype == torch.double
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
+def test_deepspeed_lightning_module_precision(tmpdir):
+    """
+    Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves tensors to half when precision 16.
+    """
+
+    model = BoringModel()
+    module = LightningDeepSpeedModule(model, precision=16)
+
+    module.cuda().half()
+    assert module.dtype == torch.half
+    assert model.dtype == torch.half
+
+    x = torch.randn((1, 32), dtype=torch.float).cuda()
+    out = module(x)
+
+    assert out.dtype == torch.half
+
+    module.to(torch.double)
+    assert module.dtype == torch.double
+    assert model.dtype == torch.double
+
+
 @pytest.fixture
 def deepspeed_config():
     return {
@@ -34,6 +75,11 @@ def deepspeed_config():
     }


+@pytest.fixture
+def deepspeed_zero_config(deepspeed_config):
+    return {**deepspeed_config, 'zero_allow_untested_optimizer': True, 'zero_optimization': {'stage': 2}}
+
+
 @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
 def test_deepspeed_plugin_string(tmpdir):
     """
@@ -179,12 +225,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
             return loss.backward()

     model = TestModel()
-    trainer = Trainer(
-        fast_dev_run=True,
-        default_root_dir=tmpdir,
-        plugins=DeepSpeedPlugin(zero_optimization=False),
-        gpus=1,
-    )
+    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, plugins=DeepSpeedPlugin(), gpus=1, precision=16)
     with pytest.warns(UserWarning, match='Overridden backward hook in the LightningModule will be ignored'):
         trainer.fit(model)

@@ -203,17 +244,21 @@ def test_deepspeed_run_configure_optimizers(tmpdir):
     class TestModel(BoringModel):

         def on_train_start(self) -> None:
-            assert isinstance(self.trainer.optimizers[0], torch.optim.SGD)
+            from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
+
+            assert isinstance(self.trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
+            assert isinstance(self.trainer.optimizers[0].optimizer, torch.optim.SGD)
             assert self.trainer.lr_schedulers == []  # DeepSpeed manages LR scheduler internally
             # Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler
             assert isinstance(self.trainer.model.lr_scheduler, torch.optim.lr_scheduler.StepLR)

     model = TestModel()
     trainer = Trainer(
-        plugins=DeepSpeedPlugin(zero_optimization=False),
+        plugins=DeepSpeedPlugin(),  # disable ZeRO so our optimizers are not wrapped
         default_root_dir=tmpdir,
         gpus=1,
         fast_dev_run=True,
+        precision=16
     )

     trainer.fit(model)
@@ -226,7 +271,7 @@ def on_train_start(self) -> None:
 @pytest.mark.skipif(
     not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
 )
-def test_deepspeed_config(tmpdir, deepspeed_config):
+def test_deepspeed_config(tmpdir, deepspeed_zero_config):
     """
     Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers
     and saves the model weights to load correctly.
@@ -235,18 +280,22 @@ def test_deepspeed_config(tmpdir, deepspeed_config):
     class TestModel(BoringModel):

         def on_train_start(self) -> None:
-            import deepspeed
-            assert isinstance(self.trainer.optimizers[0], torch.optim.SGD)
+            from deepspeed.runtime.lr_schedules import WarmupLR
+            from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
+
+            assert isinstance(self.trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
+            assert isinstance(self.trainer.optimizers[0].optimizer, torch.optim.SGD)
             assert self.trainer.lr_schedulers == []  # DeepSpeed manages LR scheduler internally
-            assert isinstance(self.trainer.model.optimizer, torch.optim.SGD)
-            assert isinstance(self.trainer.model.lr_scheduler, deepspeed.runtime.lr_schedules.WarmupLR)
+            # Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler
+            assert isinstance(self.trainer.model.lr_scheduler, WarmupLR)

     model = TestModel()
     trainer = Trainer(
-        plugins=[DeepSpeedPlugin(config=deepspeed_config)],
+        plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)],
         default_root_dir=tmpdir,
         gpus=1,
         fast_dev_run=True,
+        precision=16
     )

     trainer.fit(model)
@@ -267,7 +316,7 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config):
     """
     model = BoringModel()
     trainer = Trainer(
-        plugins=[DeepSpeedPlugin(zero_optimization=False)],
+        plugins=[DeepSpeedPlugin()],
         default_root_dir=tmpdir,
         gpus=2,
         fast_dev_run=True,
@@ -285,8 +334,9 @@ def _assert_save_model_is_equal(model, tmpdir, trainer):
     # carry out the check only on rank 0
     if trainer.global_rank == 0:
         saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
-        saved_model = saved_model.float()
-        model = model.float().cpu()
+        if model.dtype == torch.half:
+            saved_model = saved_model.half()  # model is loaded in float32 as default, move it to float16
+        model = model.cpu()
         # Assert model parameters are identical after loading
         for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()):
             assert torch.equal(orig_param, trained_model_param)
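
Taken together, the ZeRO path these tests now exercise amounts to handing the plugin a stage-2 config (or relying on its defaults) and running with `precision=16`. A condensed sketch assembled only from the fixtures and Trainer calls above; in the tests these ZeRO keys are merged into the base `deepspeed_config` fixture (optimizer/scheduler settings), which is not reproduced here:

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin
    from tests.helpers.boring_model import BoringModel

    # The same keys the new deepspeed_zero_config fixture layers on top of the base config.
    zero_config = {
        'zero_allow_untested_optimizer': True,
        'zero_optimization': {'stage': 2},
    }

    trainer = Trainer(
        plugins=[DeepSpeedPlugin(config=zero_config)],
        gpus=1,
        precision=16,  # the updated tests always pair ZeRO with 16-bit precision
        fast_dev_run=True,
    )
    trainer.fit(BoringModel())

    # As asserted in test_deepspeed_config above, trainer.optimizers[0] is then a
    # FP16_DeepSpeedZeroOptimizer wrapping the configured torch.optim.SGD.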

tests/utilities/test_dtype_device_mixin.py

Lines changed: 10 additions & 3 deletions

@@ -101,12 +101,19 @@ def test_submodules_multi_gpu_ddp_spawn(tmpdir):
     trainer.fit(model)


+@pytest.mark.parametrize(
+    ['device'],
+    [
+        pytest.param(None),  # explicitly call without an index to see if the returning device contains an index
+        pytest.param(0),
+        pytest.param(torch.device('cuda', 0)),
+    ]
+)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
-def test_gpu_device_includes_index():
+def test_gpu_cuda_device(device):
     model = TopModule()

-    # explicitly call without an index to see if the returning device contains an index (it should!)
-    model.cuda()
+    model.cuda(device)

     device = model.device
     assert device.type == 'cuda'

0 commit comments
